GitOrigin-RevId: ca5a6ed8eb
tags/v1.2.0
| @@ -12,6 +12,7 @@ import itertools | |||
| import numpy as np | |||
| from .._imperative_rt import TensorAttr, imperative | |||
| from .._imperative_rt.core2 import apply | |||
| from ..ops.builtin import ( | |||
| Broadcast, | |||
| Elemwise, | |||
| @@ -25,37 +26,6 @@ from ..ops.builtin import ( | |||
| Subtensor, | |||
| ) | |||
| from ..ops.special import Const | |||
| from ..tensor.core import apply | |||
| from ..tensor.function import Function | |||
| @functools.singledispatch | |||
| def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): | |||
| assert 0 | |||
| @builtin_op_get_backward_fn.register(OpDef) | |||
| def _(op: OpDef, inputs, outputs, input_requires_grad): | |||
| if isinstance(op, Reshape): | |||
| grad_fn = reshape_grad_fn | |||
| elif isinstance(op, Subtensor): | |||
| grad_fn = subtensor_grad_fn | |||
| elif isinstance(op, IndexingMultiAxisVec): | |||
| grad_fn = indexingMultiAxisVec_grad_fn | |||
| elif isinstance(op, Broadcast) or ( | |||
| isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD | |||
| ): | |||
| grad_fn = elemwise_add_grad_fn | |||
| elif isinstance(op, Reduce) and op.mode == Reduce.Mode.SUM: | |||
| grad_fn = reduce_sum_grad_fn | |||
| else: | |||
| grad_fn = default_grad_fn | |||
| return grad_fn(op, inputs, outputs, input_requires_grad) | |||
| @builtin_op_get_backward_fn.register(Function) | |||
| def _(op: Function, inputs, outputs, input_requires_grad): | |||
| return op.get_backward_fn(), [True,] * len(outputs) | |||
| def default_grad_fn(op, inputs, outputs, input_requires_grad): | |||
| @@ -19,8 +19,6 @@ import megengine as mge | |||
| from .._imperative_rt import core2, ops | |||
| from ..ops.builtin import Elemwise, OpDef, RemoteSend | |||
| from ..ops.special import Const | |||
| from ..tensor.core import TensorBase, TensorWrapperBase, apply | |||
| from ..tensor.function import Function | |||
| from . import builtin_op_utils | |||
| """ Some notes: | |||
| @@ -48,146 +46,6 @@ def get_grad_managers(): | |||
| return [_grad_manager_dict[key] for key in _grad_manager_dict] | |||
| def add(a, b): | |||
| (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b) | |||
| return c | |||
| def get_tensor(x): | |||
| # use recursion to avoid infinite loop | |||
| if isinstance(x, Tensor): | |||
| return x | |||
| try: | |||
| x = x.__wrapped__ | |||
| except AttributeError: | |||
| raise TypeError(type(x)) | |||
| return get_tensor(x) | |||
| class clearable: | |||
| __cleared = False | |||
| def __bool__(self): | |||
| return not self.__cleared | |||
| def clear(self): | |||
| self.__dict__.clear() | |||
| self.__cleared = True | |||
| class OpNode(clearable): | |||
| """ OpNode saves all the information to form the computational graph. | |||
| """ | |||
| def __init__(self): | |||
| self.id = None | |||
| self.inputs = None # Could be VariableNode | |||
| self.outputs = None # Could be VariableNode | |||
| self.backward = None | |||
| self.has_grad_fn = None | |||
| self.backward_allow_noinput = False | |||
| class VariableNode(clearable): | |||
| """ VariableNode saves OpNode and callback. | |||
| FIXME!!! Explain manager and owner | |||
| """ | |||
| def __init__(self, manager, owner, opnode=None, callback=None): | |||
| # manager is Grad type | |||
| self.manager = weakref.ref(manager) | |||
| # owner is Tensor type | |||
| self.owner = weakref.ref(owner) | |||
| self.opnode = opnode | |||
| self.callback = callback | |||
| class Tracer(clearable, TensorBase): | |||
| def __init__(self, node=None): | |||
| """ type(node) is VariableNode | |||
| """ | |||
| self.node = node | |||
| @functools.singledispatch | |||
| def check_backward_allow_noinput(op: OpDef): | |||
| return False | |||
| @functools.singledispatch | |||
| def get_op_has_grad_fn(op: OpDef): | |||
| assert 0 | |||
| @get_op_has_grad_fn.register(OpDef) | |||
| def _(op: OpDef): | |||
| return default_has_grad_fn | |||
| @get_op_has_grad_fn.register(Function) | |||
| def _(op: Function): | |||
| return default_has_grad_fn | |||
| def default_has_grad_fn(opnode, reached): | |||
| for v in opnode.outputs: | |||
| if v() in reached: | |||
| return True | |||
| return False | |||
| @apply.register() | |||
| def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | |||
| args = tuple(i if isinstance(i, Tracer) else None for i in args) | |||
| input_requires_grad = list(map(bool, args)) | |||
| if not any(input_requires_grad): | |||
| return | |||
| ctx = get_context() | |||
| manager = None | |||
| assert len(ctx.inputs) == len(args) | |||
| for i, j in zip(ctx.inputs, args): | |||
| if j: | |||
| j = j.node | |||
| assert i is j.owner() | |||
| if manager is None: | |||
| manager = j.manager() | |||
| assert manager | |||
| else: | |||
| assert manager is j.manager() | |||
| if not manager._enabled: | |||
| return | |||
| # register backward method | |||
| # tuple of backward functions corresponding to dy / dx_i | |||
| # None means y is not a function of x_i | |||
| backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn( | |||
| op, ctx.inputs, ctx.outputs, input_requires_grad | |||
| ) | |||
| assert len(ctx.outputs) == len(output_need_grad) | |||
| if not any(output_need_grad): | |||
| return | |||
| opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs) | |||
| if isinstance(op, RemoteSend): | |||
| manager.remote_send_cache.append(opnode) | |||
| opnode.backward = backward | |||
| outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)] | |||
| opnode.backward_allow_noinput = check_backward_allow_noinput(op) | |||
| opnode.has_grad_fn = get_op_has_grad_fn(op) | |||
| return tuple(outputs) | |||
| @apply.register() | |||
| def _(op: Const, *_: typing.Optional[Tracer]): | |||
| return None | |||
| class Grad: | |||
| def __init__(self): | |||
| self._impl = core2.GradKey() | |||
| @@ -8,9 +8,6 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| # from .._imperative_rt.core2 import Tensor | |||
| from ..tensor.core import OpBase, TensorBase, apply | |||
| class Const: | |||
| def __init__(self, value=None, *, dtype=None, device=None): | |||
| @@ -13,12 +13,9 @@ import sys | |||
| import typing | |||
| from abc import ABC | |||
| from .multipledispatch import Dispatcher | |||
| class OpBase(ABC): | |||
| def __call__(self, *args): | |||
| return apply(self, *args) | |||
| class OpBase: | |||
| pass | |||
| class TensorBase: | |||
| @@ -27,22 +24,3 @@ class TensorBase: | |||
| class TensorWrapperBase: | |||
| pass | |||
| apply = Dispatcher("apply") | |||
| OpBase.apply = apply | |||
| @apply.register() | |||
| def _(op: OpBase, *args: TensorBase): | |||
| raise NotImplementedError | |||
| @apply.register() | |||
| def _(op: OpBase, *args: TensorWrapperBase): | |||
| assert args | |||
| Wrapper = type(args[0]) | |||
| outputs = apply(op, *(i.__wrapped__ for i in args)) | |||
| assert isinstance(outputs, tuple) | |||
| return tuple(map(Wrapper, outputs)) | |||
| @@ -1,154 +0,0 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..ops.builtin import OpDef | |||
| from .core import TensorBase, TensorWrapperBase, apply | |||
| class Function: | |||
| """ | |||
| Defines a block of operations with customizable differentiation. | |||
| The computation should be defined in ``forward`` method, with gradient | |||
| computation defined in ``backward`` method. | |||
| Each instance of ``Function`` should be used only once during forwardding. | |||
| Examples: | |||
| .. testcode:: | |||
| class Sigmoid(Function): | |||
| def forward(self, x): | |||
| y = 1 / (1 + F.exp(-x)) | |||
| self.y = y | |||
| return y | |||
| def backward(self, output_grads): | |||
| y = self.y | |||
| return output_grads * y * (1-y) | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| pass | |||
| def __call__(self, *args): | |||
| ret = apply(self, *args) | |||
| if type(ret) == tuple and len(ret) == 1: | |||
| return ret[0] | |||
| return ret | |||
| def forward(self, *args, **kwargs): | |||
| """ | |||
| Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses. | |||
| :param input: input tensors. | |||
| :return: a tuple of Tensor or a single Tensor. | |||
| .. note:: | |||
| This method should return a tuple of Tensor or a single Tensor representing the output | |||
| of the function. | |||
| """ | |||
| raise NotImplementedError | |||
| def backward(self, *output_grads): | |||
| """ | |||
| Compute the gradient of the forward function. It must be overriden by all subclasses. | |||
| :param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`. | |||
| .. note:: | |||
| In case when some tensors of outputs are not related to loss function, the corresponding | |||
| values in ``output_grads`` would be ``None``. | |||
| .. note:: | |||
| This method should return a tuple which containing the gradients of all inputs, in the same order | |||
| as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned | |||
| instead if there is only one input. If users want to stop the propagation of some gradients, | |||
| the corresponding returned values should be set ``None`` . | |||
| """ | |||
| raise NotImplementedError | |||
| def get_backward_fn(self): | |||
| if self.backward is None: | |||
| return None | |||
| def _backward(*output_grads): | |||
| if type(output_grads) is tuple: | |||
| _output_grads = [ | |||
| TensorWrapper(i) if i is not None else i for i in output_grads | |||
| ] | |||
| else: | |||
| _output_grads = ( | |||
| TensorWrapper(output_grads) | |||
| if output_grads is not None | |||
| else output_grads, | |||
| ) | |||
| ret = self.backward(*_output_grads) | |||
| if type(ret) is not tuple: | |||
| ret = (ret,) | |||
| ret = tuple( | |||
| i.__wrapped__ if isinstance(i, TensorWrapper) else i for i in ret | |||
| ) | |||
| return ret | |||
| return _backward | |||
| Function.apply = Function.__call__ | |||
| @apply.register() | |||
| def _(op: Function, *args: TensorWrapperBase): | |||
| assert args | |||
| Wrapper = type(args[0]) | |||
| # compute the value for self define function | |||
| extra_data_dic = {} | |||
| for arg in args: | |||
| extra_data_dic[arg.__wrapped__] = arg.__wrapped__._extra_data | |||
| arg.__wrapped__._extra_data = {} | |||
| rets = op.forward(*args) | |||
| for arg in args: | |||
| arg.__wrapped__._extra_data = extra_data_dic[arg.__wrapped__] | |||
| # update the gradient information for self define function | |||
| inputs = tuple(map(lambda i: i.__wrapped__, args)) | |||
| outputs = ( | |||
| tuple(map(lambda i: i.__wrapped__, rets)) | |||
| if type(rets) is tuple | |||
| else (rets.__wrapped__,) | |||
| ) | |||
| for output in outputs: | |||
| if output not in inputs: | |||
| output._extra_data = {} | |||
| with push_context() as ctx: | |||
| ctx.inputs = inputs | |||
| ctx.outputs = outputs | |||
| for k in set().union(*(i._extra_data for i in inputs if isinstance(i, Tensor))): | |||
| ctx.key = k | |||
| data = tuple( | |||
| i._extra_data.get(k) if isinstance(i, Tensor) else i for i in inputs | |||
| ) | |||
| # data are instances of Tracer | |||
| # dispatched to apply.add@grad.py | |||
| rets = apply(op, *data) | |||
| if rets is not None: | |||
| assert len(outputs) == len(rets) | |||
| for t, i in zip(outputs, rets): | |||
| t._extra_data[k] = i | |||
| return tuple(map(Wrapper, outputs)) | |||
| @@ -1,53 +0,0 @@ | |||
| # Copyright (c) 2014 Matthew Rocklin | |||
| # | |||
| # All rights reserved. | |||
| # | |||
| # Redistribution and use in source and binary forms, with or without | |||
| # modification, are permitted provided that the following conditions are met: | |||
| # | |||
| # a. Redistributions of source code must retain the above copyright notice, | |||
| # this list of conditions and the following disclaimer. | |||
| # b. Redistributions in binary form must reproduce the above copyright | |||
| # notice, this list of conditions and the following disclaimer in the | |||
| # documentation and/or other materials provided with the distribution. | |||
| # c. Neither the name of multipledispatch nor the names of its contributors | |||
| # may be used to endorse or promote products derived from this software | |||
| # without specific prior written permission. | |||
| # | |||
| # | |||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||
| # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||
| # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||
| # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||
| # DAMAGE. | |||
| # | |||
| # -------------------------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # -------------------------------------------------------------------------------------- | |||
| # This directory is a fork of multipledispatch. | |||
| # | |||
| # Repo: https://github.com/mrocklin/multipledispatch | |||
| # Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834 | |||
| # Authors: Matthew Rocklin et al. | |||
| # | |||
| # The original LICENSE file is included in the ACKNOWLEDGEMENT file under | |||
| # MegEngine root directory. | |||
| from .core import dispatch | |||
| from .dispatcher import Dispatcher | |||
| @@ -1,165 +0,0 @@ | |||
| # Copyright (c) 2014 Matthew Rocklin | |||
| # | |||
| # All rights reserved. | |||
| # | |||
| # Redistribution and use in source and binary forms, with or without | |||
| # modification, are permitted provided that the following conditions are met: | |||
| # | |||
| # a. Redistributions of source code must retain the above copyright notice, | |||
| # this list of conditions and the following disclaimer. | |||
| # b. Redistributions in binary form must reproduce the above copyright | |||
| # notice, this list of conditions and the following disclaimer in the | |||
| # documentation and/or other materials provided with the distribution. | |||
| # c. Neither the name of multipledispatch nor the names of its contributors | |||
| # may be used to endorse or promote products derived from this software | |||
| # without specific prior written permission. | |||
| # | |||
| # | |||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||
| # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||
| # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||
| # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||
| # DAMAGE. | |||
| # | |||
| # -------------------------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # -------------------------------------------------------------------------------------- | |||
| from collections import OrderedDict | |||
| from .utils import _toposort, groupby | |||
| from .variadic import isvariadic | |||
| class AmbiguityWarning(Warning): | |||
| pass | |||
| def supercedes(a, b): | |||
| """ A is consistent and strictly more specific than B """ | |||
| if len(a) < len(b): | |||
| # only case is if a is empty and b is variadic | |||
| return not a and len(b) == 1 and isvariadic(b[-1]) | |||
| elif len(a) == len(b): | |||
| return all(map(issubclass, a, b)) | |||
| else: | |||
| # len(a) > len(b) | |||
| p1 = 0 | |||
| p2 = 0 | |||
| while p1 < len(a) and p2 < len(b): | |||
| cur_a = a[p1] | |||
| cur_b = b[p2] | |||
| if not (isvariadic(cur_a) or isvariadic(cur_b)): | |||
| if not issubclass(cur_a, cur_b): | |||
| return False | |||
| p1 += 1 | |||
| p2 += 1 | |||
| elif isvariadic(cur_a): | |||
| assert p1 == len(a) - 1 | |||
| return p2 == len(b) - 1 and issubclass(cur_a, cur_b) | |||
| elif isvariadic(cur_b): | |||
| assert p2 == len(b) - 1 | |||
| if not issubclass(cur_a, cur_b): | |||
| return False | |||
| p1 += 1 | |||
| return p2 == len(b) - 1 and p1 == len(a) | |||
| def consistent(a, b): | |||
| """ It is possible for an argument list to satisfy both A and B """ | |||
| # Need to check for empty args | |||
| if not a: | |||
| return not b or isvariadic(b[0]) | |||
| if not b: | |||
| return not a or isvariadic(a[0]) | |||
| # Non-empty args check for mutual subclasses | |||
| if len(a) == len(b): | |||
| return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) | |||
| else: | |||
| p1 = 0 | |||
| p2 = 0 | |||
| while p1 < len(a) and p2 < len(b): | |||
| cur_a = a[p1] | |||
| cur_b = b[p2] | |||
| if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): | |||
| return False | |||
| if not (isvariadic(cur_a) or isvariadic(cur_b)): | |||
| p1 += 1 | |||
| p2 += 1 | |||
| elif isvariadic(cur_a): | |||
| p2 += 1 | |||
| elif isvariadic(cur_b): | |||
| p1 += 1 | |||
| # We only need to check for variadic ends | |||
| # Variadic types are guaranteed to be the last element | |||
| return isvariadic(cur_a) and p2 == len(b) or isvariadic(cur_b) and p1 == len(a) | |||
| def ambiguous(a, b): | |||
| """ A is consistent with B but neither is strictly more specific """ | |||
| return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) | |||
| def ambiguities(signatures): | |||
| """ All signature pairs such that A is ambiguous with B """ | |||
| signatures = list(map(tuple, signatures)) | |||
| return set( | |||
| (a, b) | |||
| for a in signatures | |||
| for b in signatures | |||
| if hash(a) < hash(b) | |||
| and ambiguous(a, b) | |||
| and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) | |||
| ) | |||
| def super_signature(signatures): | |||
| """ A signature that would break ambiguities """ | |||
| n = len(signatures[0]) | |||
| assert all(len(s) == n for s in signatures) | |||
| return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] for i in range(n)] | |||
| def edge(a, b, tie_breaker=hash): | |||
| """ A should be checked before B | |||
| Tie broken by tie_breaker, defaults to ``hash`` | |||
| """ | |||
| # A either supercedes B and B does not supercede A or if B does then call | |||
| # tie_breaker | |||
| return supercedes(a, b) and ( | |||
| not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) | |||
| ) | |||
| def ordering(signatures): | |||
| """ A sane ordering of signatures to check, first to last | |||
| Topoological sort of edges as given by ``edge`` and ``supercedes`` | |||
| """ | |||
| signatures = list(map(tuple, signatures)) | |||
| edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] | |||
| edges = groupby(lambda x: x[0], edges) | |||
| for s in signatures: | |||
| if s not in edges: | |||
| edges[s] = [] | |||
| edges = OrderedDict((k, [b for a, b in v]) for k, v in edges.items()) | |||
| return _toposort(edges) | |||
| @@ -1,130 +0,0 @@ | |||
| # Copyright (c) 2014 Matthew Rocklin | |||
| # | |||
| # All rights reserved. | |||
| # | |||
| # Redistribution and use in source and binary forms, with or without | |||
| # modification, are permitted provided that the following conditions are met: | |||
| # | |||
| # a. Redistributions of source code must retain the above copyright notice, | |||
| # this list of conditions and the following disclaimer. | |||
| # b. Redistributions in binary form must reproduce the above copyright | |||
| # notice, this list of conditions and the following disclaimer in the | |||
| # documentation and/or other materials provided with the distribution. | |||
| # c. Neither the name of multipledispatch nor the names of its contributors | |||
| # may be used to endorse or promote products derived from this software | |||
| # without specific prior written permission. | |||
| # | |||
| # | |||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||
| # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||
| # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||
| # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||
| # DAMAGE. | |||
| # | |||
| # -------------------------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # -------------------------------------------------------------------------------------- | |||
| import inspect | |||
| import sys | |||
| from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn | |||
| global_namespace = dict() | |||
| def dispatch(*types, **kwargs): | |||
| """ Dispatch function on the types of the inputs | |||
| Supports dispatch on all non-keyword arguments. | |||
| Collects implementations based on the function name. Ignores namespaces. | |||
| If ambiguous type signatures occur a warning is raised when the function is | |||
| defined suggesting the additional method to break the ambiguity. | |||
| Examples | |||
| -------- | |||
| >>> @dispatch(int) | |||
| ... def f(x): | |||
| ... return x + 1 | |||
| >>> @dispatch(float) | |||
| ... def f(x): | |||
| ... return x - 1 | |||
| >>> f(3) | |||
| 4 | |||
| >>> f(3.0) | |||
| 2.0 | |||
| Specify an isolated namespace with the namespace keyword argument | |||
| >>> my_namespace = dict() | |||
| >>> @dispatch(int, namespace=my_namespace) | |||
| ... def foo(x): | |||
| ... return x + 1 | |||
| Dispatch on instance methods within classes | |||
| >>> class MyClass(object): | |||
| ... @dispatch(list) | |||
| ... def __init__(self, data): | |||
| ... self.data = data | |||
| ... @dispatch(int) | |||
| ... def __init__(self, datum): | |||
| ... self.data = [datum] | |||
| """ | |||
| namespace = kwargs.get("namespace", global_namespace) | |||
| types = tuple(types) | |||
| def _df(func): | |||
| name = func.__name__ | |||
| if ismethod(func): | |||
| dispatcher = inspect.currentframe().f_back.f_locals.get( | |||
| name, MethodDispatcher(name), | |||
| ) | |||
| else: | |||
| if name not in namespace: | |||
| namespace[name] = Dispatcher(name) | |||
| dispatcher = namespace[name] | |||
| dispatcher.add(types, func) | |||
| return dispatcher | |||
| return _df | |||
| def ismethod(func): | |||
| """ Is func a method? | |||
| Note that this has to work as the method is defined but before the class is | |||
| defined. At this stage methods look like functions. | |||
| """ | |||
| if hasattr(inspect, "signature"): | |||
| signature = inspect.signature(func) | |||
| return signature.parameters.get("self", None) is not None | |||
| else: | |||
| if sys.version_info.major < 3: | |||
| spec = inspect.getargspec(func) | |||
| else: | |||
| spec = inspect.getfullargspec(func) | |||
| return spec and spec.args and spec.args[0] == "self" | |||
| @@ -1,445 +0,0 @@ | |||
| # Copyright (c) 2014 Matthew Rocklin | |||
| # | |||
| # All rights reserved. | |||
| # | |||
| # Redistribution and use in source and binary forms, with or without | |||
| # modification, are permitted provided that the following conditions are met: | |||
| # | |||
| # a. Redistributions of source code must retain the above copyright notice, | |||
| # this list of conditions and the following disclaimer. | |||
| # b. Redistributions in binary form must reproduce the above copyright | |||
| # notice, this list of conditions and the following disclaimer in the | |||
| # documentation and/or other materials provided with the distribution. | |||
| # c. Neither the name of multipledispatch nor the names of its contributors | |||
| # may be used to endorse or promote products derived from this software | |||
| # without specific prior written permission. | |||
| # | |||
| # | |||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||
| # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||
| # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||
| # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||
| # DAMAGE. | |||
| # | |||
| # -------------------------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # -------------------------------------------------------------------------------------- | |||
| import copy | |||
| import inspect | |||
| import itertools as itl | |||
| from warnings import warn | |||
| from ..._imperative_rt.dispatcher import Dispatcher as CDispatcher | |||
| from .conflict import AmbiguityWarning, ambiguities, ordering, super_signature | |||
| from .utils import expand_tuples, parse_union | |||
| from .variadic import Variadic, isvariadic | |||
| def ambiguity_warn(dispatcher, ambiguities): | |||
| """ Raise warning when ambiguity is detected | |||
| Parameters | |||
| ---------- | |||
| dispatcher : Dispatcher | |||
| The dispatcher on which the ambiguity was detected | |||
| ambiguities : set | |||
| Set of type signature pairs that are ambiguous within this dispatcher | |||
| See Also: | |||
| Dispatcher.add | |||
| warning_text | |||
| """ | |||
| warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) | |||
| def variadic_signature_matches_iter(types, full_signature): | |||
| """ | |||
| Check if a set of input types matches a variadic signature. | |||
| Notes | |||
| ----- | |||
| The algorithm is as follows: | |||
| Initialize the current signature to the first in the sequence | |||
| For each type in `types`: | |||
| If the current signature is variadic | |||
| If the type matches the signature | |||
| yield True | |||
| Else | |||
| Try to get the next signature | |||
| If no signatures are left we can't possibly have a match | |||
| so yield False | |||
| Else | |||
| yield True if the type matches the current signature | |||
| Get the next signature | |||
| """ | |||
| sigiter = iter(full_signature) | |||
| sig = next(sigiter) | |||
| for typ in types: | |||
| matches = issubclass(typ, sig) | |||
| yield matches | |||
| if not isvariadic(sig): | |||
| # we're not matching a variadic argument, so move to the next | |||
| # element in the signature | |||
| sig = next(sigiter) | |||
| else: | |||
| try: | |||
| sig = next(sigiter) | |||
| except StopIteration: | |||
| assert isvariadic(sig) | |||
| yield True | |||
| else: | |||
| # We have signature items left over, so all of our arguments | |||
| # haven't matched | |||
| yield False | |||
| def variadic_signature_matches(types, full_signature): | |||
| # No arguments always matches a variadic signature | |||
| assert full_signature | |||
| return all(variadic_signature_matches_iter(types, full_signature)) | |||
| def get_func_signature(function): | |||
| sig = inspect.signature(function) | |||
| types = [] | |||
| for p in sig.parameters.values(): | |||
| ann = p.annotation | |||
| ann = parse_union(ann) or ann | |||
| if p.kind in ( | |||
| inspect.Parameter.POSITIONAL_ONLY, | |||
| inspect.Parameter.POSITIONAL_OR_KEYWORD, | |||
| ): | |||
| types.append(ann) | |||
| if p.kind == inspect.Parameter.VAR_POSITIONAL: | |||
| types.append([ann]) | |||
| return tuple(types) | |||
| class Frame: | |||
| __slots__ = "args", "types", "mro", "mro_offset" | |||
| class Dispatcher(CDispatcher): | |||
| """ Dispatch methods based on type signature | |||
| Use ``dispatch`` to add implementations | |||
| Examples | |||
| -------- | |||
| >>> from multipledispatch import dispatch | |||
| >>> @dispatch(int) | |||
| ... def f(x): | |||
| ... return x + 1 | |||
| >>> @dispatch(float) | |||
| ... def f(x): | |||
| ... return x - 1 | |||
| >>> f(3) | |||
| 4 | |||
| >>> f(3.0) | |||
| 2.0 | |||
| """ | |||
| __slots__ = "__name__", "name", "funcs", "_ordering", "doc" | |||
| def __init__(self, name, doc=None): | |||
| self.name = self.__name__ = name | |||
| self.funcs = {} | |||
| self.doc = doc | |||
| def register(self, *types, **kwargs): | |||
| """ register dispatcher with new implementation | |||
| >>> f = Dispatcher('f') | |||
| >>> @f.register(int) | |||
| ... def inc(x): | |||
| ... return x + 1 | |||
| >>> @f.register(float) | |||
| ... def dec(x): | |||
| ... return x - 1 | |||
| >>> @f.register(list) | |||
| ... @f.register(tuple) | |||
| ... def reverse(x): | |||
| ... return x[::-1] | |||
| >>> f(1) | |||
| 2 | |||
| >>> f(1.0) | |||
| 0.0 | |||
| >>> f([1, 2, 3]) | |||
| [3, 2, 1] | |||
| """ | |||
| def _df(func): | |||
| self.add(types, func, **kwargs) | |||
| return func | |||
| return _df | |||
| def add(self, signature, func): | |||
| """ Add new types/method pair to dispatcher | |||
| >>> D = Dispatcher('add') | |||
| >>> D.add((int, int), lambda x, y: x + y) | |||
| >>> D.add((float, float), lambda x, y: x + y) | |||
| >>> D(1, 2) | |||
| 3 | |||
| >>> D(1, 2.0) | |||
| Traceback (most recent call last): | |||
| ... | |||
| NotImplementedError: Could not find signature for add: <int, float> | |||
| When ``add`` detects a warning it calls the ``on_ambiguity`` callback | |||
| with a dispatcher/itself, and a set of ambiguous type signature pairs | |||
| as inputs. See ``ambiguity_warn`` for an example. | |||
| """ | |||
| # Handle annotations | |||
| if not signature: | |||
| signature = get_func_signature(func) | |||
| # Handle union types | |||
| if any(isinstance(typ, tuple) for typ in signature): | |||
| for typs in expand_tuples(signature): | |||
| self.add(typs, func) | |||
| return | |||
| new_signature = [] | |||
| for index, typ in enumerate(signature, start=1): | |||
| if not isinstance(typ, (type, list)): | |||
| str_sig = ", ".join( | |||
| c.__name__ if isinstance(c, type) else str(c) for c in signature | |||
| ) | |||
| raise TypeError( | |||
| "Tried to dispatch on non-type: %s\n" | |||
| "In signature: <%s>\n" | |||
| "In function: %s" % (typ, str_sig, self.name) | |||
| ) | |||
| # handle variadic signatures | |||
| if isinstance(typ, list): | |||
| if index != len(signature): | |||
| raise TypeError("Variadic signature must be the last element") | |||
| if len(typ) != 1: | |||
| raise TypeError( | |||
| "Variadic signature must contain exactly one element. " | |||
| "To use a variadic union type place the desired types " | |||
| "inside of a tuple, e.g., [(int, str)]" | |||
| ) | |||
| new_signature.append(Variadic[typ[0]]) | |||
| else: | |||
| new_signature.append(typ) | |||
| l = self.funcs.setdefault(tuple(new_signature), []) | |||
| for i in l: | |||
| if i is func: | |||
| raise ValueError("already registered") | |||
| l.append(func) | |||
| self.enable(func) | |||
| self.clear_cache() | |||
| try: | |||
| del self._ordering | |||
| except AttributeError: | |||
| pass | |||
| @property | |||
| def ordering(self): | |||
| try: | |||
| return self._ordering | |||
| except AttributeError: | |||
| return self.reorder() | |||
| def reorder(self, on_ambiguity=ambiguity_warn): | |||
| self._ordering = od = ordering(self.funcs) | |||
| amb = ambiguities(self.funcs) | |||
| if amb: | |||
| on_ambiguity(self, amb) | |||
| return od | |||
| def __str__(self): | |||
| return "<dispatched %s>" % self.name | |||
| __repr__ = __str__ | |||
| def dispatch(self, *types): | |||
| """ | |||
| Deterimine appropriate implementation for this type signature | |||
| This method is internal. Users should call this object as a function. | |||
| Implementation resolution occurs within the ``__call__`` method. | |||
| >>> from multipledispatch import dispatch | |||
| >>> @dispatch(int) | |||
| ... def inc(x): | |||
| ... return x + 1 | |||
| >>> implementation = inc.dispatch(int) | |||
| >>> implementation(3) | |||
| 4 | |||
| >>> print(inc.dispatch(float)) | |||
| None | |||
| See Also: | |||
| ``multipledispatch.conflict`` - module to determine resolution order | |||
| """ | |||
| if types in self.funcs: | |||
| return self.funcs[types][-1] | |||
| for f in self.dispatch_iter(*types): | |||
| return f | |||
| def dispatch_iter(self, *types): | |||
| n = len(types) | |||
| for signature in self.ordering: | |||
| if ( | |||
| len(signature) == n | |||
| and all(map(issubclass, types, signature)) | |||
| or len(signature) | |||
| and isvariadic(signature[-1]) | |||
| and variadic_signature_matches(types, signature) | |||
| ): | |||
| yield from self.funcs[signature][::-1] | |||
| def __getstate__(self): | |||
| return {"name": self.name, "funcs": self.funcs} | |||
| def __setstate__(self, d): | |||
| self.name = d["name"] | |||
| self.funcs = d["funcs"] | |||
| self._ordering = ordering(self.funcs) | |||
| self._cache = dict() | |||
| @property | |||
| def __doc__(self): | |||
| docs = ["Multiply dispatched method: %s" % self.name] | |||
| if self.doc: | |||
| docs.append(self.doc) | |||
| other = [] | |||
| for sig in self.ordering[::-1]: | |||
| funcs = self.funcs[sig] | |||
| s = "Inputs: <%s>\n" % str_signature(sig) | |||
| sep = "-" * len(s) + "\n" | |||
| for i, func in enumerate(funcs): | |||
| s += sep | |||
| if len(funcs) > 1: | |||
| s += "[Handler %d]\n\n" % (i + 1) | |||
| if i: | |||
| s += "\n\n" | |||
| if func.__doc__: | |||
| s += func.__doc__.strip() | |||
| else: | |||
| s += repr(func) + "\n" | |||
| docs.append(s) | |||
| return "\n\n".join(docs) | |||
| def _help(self, *args): | |||
| return self.dispatch(*map(type, args)).__doc__ | |||
| def help(self, *args, **kwargs): | |||
| """ Print docstring for the function corresponding to inputs """ | |||
| print(self._help(*args)) | |||
| def _source(self, *args): | |||
| func = self.dispatch(*map(type, args)) | |||
| if not func: | |||
| raise TypeError("No function found") | |||
| return source(func) | |||
| def source(self, *args, **kwargs): | |||
| """ Print source code for the function corresponding to inputs """ | |||
| print(self._source(*args)) | |||
| def source(func): | |||
| s = "File: %s\n\n" % inspect.getsourcefile(func) | |||
| s = s + inspect.getsource(func) | |||
| return s | |||
| class MethodDispatcher(Dispatcher): | |||
| """ Dispatch methods based on type signature | |||
| See Also: | |||
| Dispatcher | |||
| """ | |||
| __slots__ = ("obj", "cls") | |||
| @classmethod | |||
| def get_func_params(cls, func): | |||
| if hasattr(inspect, "signature"): | |||
| sig = inspect.signature(func) | |||
| return itl.islice(sig.parameters.values(), 1, None) | |||
| def __get__(self, instance, owner): | |||
| self.obj = instance | |||
| self.cls = owner | |||
| return self | |||
| def __call__(self, *args, **kwargs): | |||
| types = tuple([type(arg) for arg in args]) | |||
| func = self.dispatch(*types) | |||
| if not func: | |||
| raise NotImplementedError( | |||
| "Could not find signature for %s: <%s>" | |||
| % (self.name, str_signature(types)) | |||
| ) | |||
| return func(self.obj, *args, **kwargs) | |||
| def str_signature(sig): | |||
| """ String representation of type signature | |||
| >>> str_signature((int, float)) | |||
| 'int, float' | |||
| """ | |||
| return ", ".join(cls.__name__ for cls in sig) | |||
| def warning_text(name, amb): | |||
| """ The text for ambiguity warnings """ | |||
| text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) | |||
| text += "The following signatures may result in ambiguous behavior:\n" | |||
| for pair in amb: | |||
| text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" | |||
| text += "\n\nConsider making the following additions:\n\n" | |||
| text += "\n\n".join( | |||
| [ | |||
| "@dispatch(" + str_signature(super_signature(s)) + ")\ndef %s(...)" % name | |||
| for s in amb | |||
| ] | |||
| ) | |||
| return text | |||
| @@ -1,210 +0,0 @@ | |||
| # Copyright (c) 2014 Matthew Rocklin | |||
| # | |||
| # All rights reserved. | |||
| # | |||
| # Redistribution and use in source and binary forms, with or without | |||
| # modification, are permitted provided that the following conditions are met: | |||
| # | |||
| # a. Redistributions of source code must retain the above copyright notice, | |||
| # this list of conditions and the following disclaimer. | |||
| # b. Redistributions in binary form must reproduce the above copyright | |||
| # notice, this list of conditions and the following disclaimer in the | |||
| # documentation and/or other materials provided with the distribution. | |||
| # c. Neither the name of multipledispatch nor the names of its contributors | |||
| # may be used to endorse or promote products derived from this software | |||
| # without specific prior written permission. | |||
| # | |||
| # | |||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||
| # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||
| # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||
| # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||
| # DAMAGE. | |||
| # | |||
| # -------------------------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # -------------------------------------------------------------------------------------- | |||
| import sys | |||
| import typing | |||
| from collections import OrderedDict | |||
| def raises(err, lamda): | |||
| try: | |||
| lamda() | |||
| return False | |||
| except err: | |||
| return True | |||
| def expand_tuples(L): | |||
| """ | |||
| >>> expand_tuples([1, (2, 3)]) | |||
| [(1, 2), (1, 3)] | |||
| >>> expand_tuples([1, 2]) | |||
| [(1, 2)] | |||
| """ | |||
| if not L: | |||
| return [()] | |||
| elif not isinstance(L[0], tuple): | |||
| rest = expand_tuples(L[1:]) | |||
| return [(L[0],) + t for t in rest] | |||
| else: | |||
| rest = expand_tuples(L[1:]) | |||
| return [(item,) + t for t in rest for item in L[0]] | |||
| # Taken from theano/theano/gof/sched.py | |||
| # Avoids licensing issues because this was written by Matthew Rocklin | |||
| def _toposort(edges): | |||
| """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) | |||
| inputs: | |||
| edges - a dict of the form {a: {b, c}} where b and c depend on a | |||
| outputs: | |||
| L - an ordered list of nodes that satisfy the dependencies of edges | |||
| >>> _toposort({1: (2, 3), 2: (3, )}) | |||
| [1, 2, 3] | |||
| Closely follows the wikipedia page [2] | |||
| [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", | |||
| Communications of the ACM | |||
| [2] http://en.wikipedia.org/wiki/Toposort#Algorithms | |||
| """ | |||
| incoming_edges = reverse_dict(edges) | |||
| incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) | |||
| S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) | |||
| L = [] | |||
| while S: | |||
| n, _ = S.popitem() | |||
| L.append(n) | |||
| for m in edges.get(n, ()): | |||
| assert n in incoming_edges[m] | |||
| incoming_edges[m].remove(n) | |||
| if not incoming_edges[m]: | |||
| S[m] = None | |||
| if any(incoming_edges.get(v, None) for v in edges): | |||
| raise ValueError("Input has cycles") | |||
| return L | |||
| def reverse_dict(d): | |||
| """ | |||
| Reverses direction of dependence dict | |||
| >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} | |||
| >>> reverse_dict(d) # doctest: +SKIP | |||
| {1: ('a',), 2: ('a', 'b'), 3: ('b',)} | |||
| :note: dict order are not deterministic. As we iterate on the | |||
| input dict, it make the output of this function depend on the | |||
| dict order. So this function output order should be considered | |||
| as undeterministic. | |||
| """ | |||
| result = OrderedDict() | |||
| for key in d: | |||
| for val in d[key]: | |||
| result[val] = result.get(val, tuple()) + (key,) | |||
| return result | |||
| # Taken from toolz | |||
| # Avoids licensing issues because this version was authored by Matthew Rocklin | |||
| def groupby(func, seq): | |||
| """ Group a collection by a key function | |||
| >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] | |||
| >>> groupby(len, names) # doctest: +SKIP | |||
| {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} | |||
| >>> iseven = lambda x: x % 2 == 0 | |||
| >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP | |||
| {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} | |||
| See Also: | |||
| ``countby`` | |||
| """ | |||
| d = OrderedDict() | |||
| for item in seq: | |||
| key = func(item) | |||
| if key not in d: | |||
| d[key] = list() | |||
| d[key].append(item) | |||
| return d | |||
| def typename(type): | |||
| """ | |||
| Get the name of `type`. | |||
| Parameters | |||
| ---------- | |||
| type : Union[Type, Tuple[Type]] | |||
| Returns | |||
| ------- | |||
| str | |||
| The name of `type` or a tuple of the names of the types in `type`. | |||
| Examples | |||
| -------- | |||
| >>> typename(int) | |||
| 'int' | |||
| >>> typename((int, float)) | |||
| '(int, float)' | |||
| """ | |||
| try: | |||
| return type.__name__ | |||
| except AttributeError: | |||
| if len(type) == 1: | |||
| return typename(*type) | |||
| return "(%s)" % ", ".join(map(typename, type)) | |||
| # parse typing.Union | |||
| def parse_union(ann): | |||
| if hasattr(typing, "UnionMeta"): | |||
| if type(ann) is not typing.UnionMeta: | |||
| return | |||
| return ann.__union_params__ | |||
| elif hasattr(typing, "_Union"): | |||
| if type(ann) is not typing._Union: | |||
| return | |||
| return ann.__args__ | |||
| elif hasattr(typing, "_GenericAlias"): | |||
| if type(ann) is not typing._GenericAlias: | |||
| if type(ann) is not typing.Union: | |||
| return | |||
| else: | |||
| if ann.__origin__ is not typing.Union: | |||
| return | |||
| return ann.__args__ | |||
| elif hasattr(typing, "Union"): | |||
| if typing.get_origin(ann) is not typing.Union: | |||
| return | |||
| return typing.get_args(ann) | |||
| else: | |||
| raise NotImplementedError("unsupported Python version") | |||
| @@ -1,140 +0,0 @@ | |||
| # Copyright (c) 2014 Matthew Rocklin | |||
| # | |||
| # All rights reserved. | |||
| # | |||
| # Redistribution and use in source and binary forms, with or without | |||
| # modification, are permitted provided that the following conditions are met: | |||
| # | |||
| # a. Redistributions of source code must retain the above copyright notice, | |||
| # this list of conditions and the following disclaimer. | |||
| # b. Redistributions in binary form must reproduce the above copyright | |||
| # notice, this list of conditions and the following disclaimer in the | |||
| # documentation and/or other materials provided with the distribution. | |||
| # c. Neither the name of multipledispatch nor the names of its contributors | |||
| # may be used to endorse or promote products derived from this software | |||
| # without specific prior written permission. | |||
| # | |||
| # | |||
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
| # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||
| # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||
| # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||
| # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||
| # DAMAGE. | |||
| # | |||
| # -------------------------------------------------------------------------------------- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # | |||
| # This file has been modified by Megvii ("Megvii Modifications"). | |||
| # All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. | |||
| # -------------------------------------------------------------------------------------- | |||
| from .utils import typename | |||
| class VariadicSignatureType(type): | |||
| # checking if subclass is a subclass of self | |||
| def __subclasscheck__(self, subclass): | |||
| other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) | |||
| return subclass is self or all( | |||
| issubclass(other, self.variadic_type) for other in other_type | |||
| ) | |||
| def __eq__(self, other): | |||
| """ | |||
| Return True if other has the same variadic type | |||
| Parameters | |||
| ---------- | |||
| other : object (type) | |||
| The object (type) to check | |||
| Returns | |||
| ------- | |||
| bool | |||
| Whether or not `other` is equal to `self` | |||
| """ | |||
| return isvariadic(other) and set(self.variadic_type) == set(other.variadic_type) | |||
| def __hash__(self): | |||
| return hash((type(self), frozenset(self.variadic_type))) | |||
| def isvariadic(obj): | |||
| """ | |||
| Check whether the type `obj` is variadic. | |||
| Parameters | |||
| ---------- | |||
| obj : type | |||
| The type to check | |||
| Returns | |||
| ------- | |||
| bool | |||
| Whether or not `obj` is variadic | |||
| Examples | |||
| -------- | |||
| >>> isvariadic(int) | |||
| False | |||
| >>> isvariadic(Variadic[int]) | |||
| True | |||
| """ | |||
| return isinstance(obj, VariadicSignatureType) | |||
| class VariadicSignatureMeta(type): | |||
| """ | |||
| A metaclass that overrides ``__getitem__`` on the class. This is used to | |||
| generate a new type for Variadic signatures. See the Variadic class for | |||
| examples of how this behaves. | |||
| """ | |||
| def __getitem__(self, variadic_type): | |||
| if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): | |||
| raise ValueError( | |||
| "Variadic types must be type or tuple of types" | |||
| " (Variadic[int] or Variadic[(int, float)]" | |||
| ) | |||
| if not isinstance(variadic_type, tuple): | |||
| variadic_type = (variadic_type,) | |||
| return VariadicSignatureType( | |||
| "Variadic[%s]" % typename(variadic_type), | |||
| (), | |||
| dict(variadic_type=variadic_type, __slots__=()), | |||
| ) | |||
| class Variadic(metaclass=VariadicSignatureMeta): | |||
| """ | |||
| A class whose getitem method can be used to generate a new type | |||
| representing a specific variadic signature. | |||
| Examples | |||
| -------- | |||
| >>> Variadic[int] # any number of int arguments | |||
| <class 'multipledispatch.variadic.Variadic[int]'> | |||
| >>> Variadic[(int, str)] # any number of one of int or str arguments | |||
| <class 'multipledispatch.variadic.Variadic[(int, str)]'> | |||
| >>> issubclass(int, Variadic[int]) | |||
| True | |||
| >>> issubclass(int, Variadic[(int, str)]) | |||
| True | |||
| >>> issubclass(str, Variadic[(int, str)]) | |||
| True | |||
| >>> issubclass(float, Variadic[(int, str)]) | |||
| False | |||
| """ | |||
| @@ -1,136 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import functools | |||
| import numpy as np | |||
| from ..._imperative_rt import CompNode, DeviceTensorND | |||
| from ..._imperative_rt.imperative import ( | |||
| _drop, | |||
| _get_dev_tensor, | |||
| _swap_in, | |||
| _swap_out, | |||
| apply_op, | |||
| delete, | |||
| get_device, | |||
| get_dtype, | |||
| get_shape, | |||
| get_value, | |||
| put, | |||
| ) | |||
| from ..._wrap import device as as_device | |||
| from ...ops.builtin import Copy, OpDef, TypeCvt | |||
| from ...ops.special import Const | |||
| from ..core import OpBase, TensorBase, apply | |||
| class RawTensor(TensorBase): | |||
| _init_cb = None | |||
| _del_cb = None | |||
| _handle = None | |||
| def __init__(self, handle=None, isscalar=False): | |||
| self._handle = handle | |||
| self._isscalar = isscalar | |||
| if handle is not None: | |||
| if self._init_cb: | |||
| self._init_cb() | |||
| @property | |||
| def dtype(self): | |||
| return get_dtype(self._handle) | |||
| @property | |||
| def device(self): | |||
| return as_device(get_device(self._handle)) | |||
| @property | |||
| def shape(self): | |||
| if self._isscalar: | |||
| return () | |||
| return get_shape(self._handle) | |||
| def numpy(self): | |||
| ret = get_value(self._handle) | |||
| if self._isscalar: | |||
| ret = ret.squeeze() | |||
| return ret | |||
| def _dev_tensor(self): | |||
| return _get_dev_tensor(self._handle) | |||
| def _drop(self): | |||
| _drop(self._handle) | |||
| def _swap_in(self): | |||
| _swap_in(self._handle) | |||
| def _swap_out(self): | |||
| _swap_out(self._handle) | |||
| def __repr__(self): | |||
| return "{}({}, device='{}')".format( | |||
| type(self).__qualname__, repr(self.numpy()), self.device | |||
| ) | |||
| def __del__(self): | |||
| if self._handle is not None: | |||
| if self._del_cb: | |||
| self._del_cb() | |||
| delete(self._handle) | |||
| @apply.register() | |||
| def _(op: OpDef, *args: RawTensor): | |||
| outputs = apply_op(op, tuple(i._handle for i in args)) | |||
| return tuple(map(RawTensor, outputs)) | |||
| @apply.register() | |||
| def _(op: Const, *args: RawTensor): | |||
| dtype = op.dtype | |||
| device = as_device(op.device).to_c() | |||
| return (as_raw_tensor(op.value, dtype=dtype, device=device),) | |||
| @functools.singledispatch | |||
| def as_raw_tensor(obj, dtype=None, device=None): | |||
| obj = np.asarray(obj, dtype=dtype) | |||
| if obj.dtype == np.float64: | |||
| obj = obj.astype(np.float32) | |||
| if obj.dtype == np.int64: | |||
| obj = obj.astype(np.int32) | |||
| return as_raw_tensor(obj, device=device) | |||
| @as_raw_tensor.register(DeviceTensorND) | |||
| def _(data: DeviceTensorND): | |||
| return RawTensor(put(data)) | |||
| @as_raw_tensor.register(np.ndarray) | |||
| def _(array: np.ndarray, dtype=None, device=None): | |||
| device = None if device is None else as_device(device).to_c() | |||
| if 0 in array.strides: | |||
| array = array.squeeze().reshape(array.shape) | |||
| return RawTensor(put(array, dtype=dtype, device=device), isscalar=(array.ndim == 0)) | |||
| @as_raw_tensor.register(RawTensor) | |||
| def _(tensor: RawTensor, dtype=None, device=None): | |||
| if dtype is not None: | |||
| dtype = np.dtype(dtype) | |||
| if dtype != tensor.dtype: | |||
| (tensor,) = apply(TypeCvt(dtype=dtype), tensor) | |||
| if device is not None: | |||
| device = as_device(device) | |||
| if device != tensor.device: | |||
| (tensor,) = apply(Copy(comp_node=device.to_c()), tensor) | |||
| return tensor | |||
| @@ -9,14 +9,7 @@ | |||
| from typing import Optional, Tuple | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | |||
| from ..core.autodiff.grad import ( | |||
| Tracer, | |||
| check_backward_allow_noinput, | |||
| get_grad_managers, | |||
| get_op_has_grad_fn, | |||
| tracer_apply, | |||
| ) | |||
| from ..core.autodiff.grad import get_grad_managers | |||
| from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
| from ..device import get_default_device | |||
| from ..tensor import Tensor | |||
| @@ -236,7 +229,7 @@ def remote_recv( | |||
| device = get_default_device() | |||
| # dummy input | |||
| if inp == None: | |||
| inp = tensor([0], device=device) | |||
| inp = Tensor([0], device=device) | |||
| tracer_set = get_client().check_remote_tracer(key) | |||
| for grad_manager in get_grad_managers(): | |||
| if grad_manager.name in tracer_set: | |||
| @@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): | |||
| outputs = apply(op, inp) | |||
| for s, x in zip(shapes, outputs): | |||
| if not s: | |||
| x._isscalar = True | |||
| x.setscalar() | |||
| return outputs | |||
| @@ -10,7 +10,7 @@ | |||
| from typing import Optional, Sequence, Tuple, Union | |||
| from ..core._imperative_rt import CompNode | |||
| from ..core._imperative_rt.core2 import Tensor, apply | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core._trace_option import use_symbolic_shape | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import BatchNorm | |||
| @@ -12,10 +12,10 @@ from typing import Dict | |||
| import numpy as np | |||
| from .. import functional as F | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core.autodiff.grad import Function | |||
| from ..core.ops import builtin | |||
| from ..core.tensor import megbrain_graph | |||
| from ..core.tensor.core import apply | |||
| from ..core.tensor.dtype import _metadata_dict | |||
| from ..tensor import Tensor | |||
| @@ -3,7 +3,7 @@ import sys | |||
| import pytest | |||
| from megengine.core._imperative_rt.imperative import sync | |||
| from megengine.core._imperative_rt.core2 import sync | |||
| sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) | |||
| @@ -4,7 +4,6 @@ import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.optimizer as optimizer | |||
| from megengine import Parameter, tensor | |||
| from megengine.core.tensor.raw_tensor import RawTensor | |||
| from megengine.module import Module | |||
| @@ -13,7 +13,6 @@ import pytest | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| from megengine.core.ops import builtin as ops | |||
| from megengine.core.tensor.core import apply | |||
| from megengine.core.tensor.dtype import ( | |||
| _metadata_dict, | |||
| convert_from_qint4, | |||
| @@ -1,58 +0,0 @@ | |||
| from megengine.core.tensor.multipledispatch import Dispatcher | |||
| def test_register_many(): | |||
| f = Dispatcher("f") | |||
| log = [] | |||
| @f.register() | |||
| def _(x: int): | |||
| log.append("a") | |||
| return log[-1] | |||
| @f.register() | |||
| def _(x: int): | |||
| log.append("b") | |||
| return log[-1] | |||
| assert f(0) == "b" | |||
| assert log == ["b"] | |||
| def test_return_not_implemented(): | |||
| f = Dispatcher("f") | |||
| log = [] | |||
| @f.register() | |||
| def _(x: int): | |||
| log.append("a") | |||
| return log[-1] | |||
| @f.register() | |||
| def _(x: int): | |||
| log.append("b") | |||
| return NotImplemented | |||
| assert f(0) == "a" | |||
| assert log == ["b", "a"] | |||
| def test_super(): | |||
| f = Dispatcher("f") | |||
| log = [] | |||
| @f.register() | |||
| def _(x: int): | |||
| log.append("a") | |||
| return log[-1] | |||
| @f.register() | |||
| def _(x: int): | |||
| log.append("b") | |||
| return f.super(x) | |||
| assert f(0) == "a" | |||
| assert log == ["b", "a"] | |||