GitOrigin-RevId: a7c25a4302
tags/v1.0.0-rc1
| @@ -209,6 +209,44 @@ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |||||
| ********************************************************************************************************************************* | |||||
| multipledispatch | |||||
| -------------------------------------------------------------------- | |||||
| Copyright (c) 2014 Matthew Rocklin | |||||
| All rights reserved. | |||||
| Redistribution and use in source and binary forms, with or without | |||||
| modification, are permitted provided that the following conditions are met: | |||||
| a. Redistributions of source code must retain the above copyright notice, | |||||
| this list of conditions and the following disclaimer. | |||||
| b. Redistributions in binary form must reproduce the above copyright | |||||
| notice, this list of conditions and the following disclaimer in the | |||||
| documentation and/or other materials provided with the distribution. | |||||
| c. Neither the name of multipledispatch nor the names of its contributors | |||||
| may be used to endorse or promote products derived from this software | |||||
| without specific prior written permission. | |||||
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||||
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||||
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||||
| ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | |||||
| ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |||||
| DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |||||
| SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |||||
| CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | |||||
| LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | |||||
| OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | |||||
| DAMAGE. | |||||
| ********************************************************************************************************************************* | |||||
| ********************************************************************************************************************************* | ********************************************************************************************************************************* | ||||
| Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-party Components therein: | Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-party Components therein: | ||||
| -------------------------------------------------------------------- | -------------------------------------------------------------------- | ||||
| @@ -343,7 +343,7 @@ def default_has_grad_fn(opnode, reached): | |||||
| return False | return False | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | ||||
| args = tuple(i if isinstance(i, Tracer) else None for i in args) | args = tuple(i if isinstance(i, Tracer) else None for i in args) | ||||
| input_requires_grad = list(map(bool, args)) | input_requires_grad = list(map(bool, args)) | ||||
| @@ -385,6 +385,6 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): | |||||
| return tuple(outputs) | return tuple(outputs) | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: Const, *_: typing.Optional[Tracer]): | def _(op: Const, *_: typing.Optional[Tracer]): | ||||
| return None | return None | ||||
| @@ -19,7 +19,7 @@ from .._internal.helper import PodOpVisitor | |||||
| OpBase.register(OpDef) | OpBase.register(OpDef) | ||||
| # forward to apply(OpDef, ...) | # forward to apply(OpDef, ...) | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]): | def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]): | ||||
| return apply(op.to_c(), *args) | return apply(op.to_c(), *args) | ||||
| @@ -13,7 +13,7 @@ import sys | |||||
| import typing | import typing | ||||
| from abc import ABC | from abc import ABC | ||||
| import multipledispatch | |||||
| from .multipledispatch import Dispatcher | |||||
| class OpBase(ABC): | class OpBase(ABC): | ||||
| @@ -29,84 +29,17 @@ class TensorWrapperBase: | |||||
| pass | pass | ||||
| class Dispatcher(multipledispatch.Dispatcher): | |||||
| def add(self, f, g=None): | |||||
| if g is None: | |||||
| super().add(get_signature(f), f) | |||||
| else: | |||||
| super().add(f, g) | |||||
| return f | |||||
| def __get__(self, instance, owner=None): | |||||
| if instance is not None: | |||||
| return self | |||||
| return functools.partial(self, instance) | |||||
| if sys.version_info < (3, 6): | |||||
| def parse_union(ann): | |||||
| if type(ann) is not typing.UnionMeta: | |||||
| return | |||||
| return ann.__union_params__ | |||||
| elif sys.version_info < (3, 7): | |||||
| def parse_union(ann): | |||||
| if type(ann) is not typing._Union: | |||||
| return | |||||
| return ann.__args__ | |||||
| elif sys.version_info < (3, 8): | |||||
| def parse_union(ann): | |||||
| if type(ann) is not typing._GenericAlias: | |||||
| if type(ann) is not typing.Union: | |||||
| return | |||||
| else: | |||||
| if ann.__origin__ is not typing.Union: | |||||
| return | |||||
| return ann.__args__ | |||||
| else: | |||||
| def parse_union(ann): | |||||
| if typing.get_origin(ann) is not typing.Union: | |||||
| return | |||||
| return typing.get_args(ann) | |||||
| def get_signature(function, op_type=None): | |||||
| sig = inspect.signature(function) | |||||
| types = [] | |||||
| for p in sig.parameters.values(): | |||||
| ann = p.annotation | |||||
| ann = parse_union(ann) or ann | |||||
| if p.kind in ( | |||||
| inspect.Parameter.POSITIONAL_ONLY, | |||||
| inspect.Parameter.POSITIONAL_OR_KEYWORD, | |||||
| ): | |||||
| types.append(ann) | |||||
| if p.kind == inspect.Parameter.VAR_POSITIONAL: | |||||
| types.append([ann]) | |||||
| return tuple(types) | |||||
| apply = Dispatcher("apply") | apply = Dispatcher("apply") | ||||
| OpBase.apply = apply | OpBase.apply = apply | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: OpBase, *args: TensorBase): | def _(op: OpBase, *args: TensorBase): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: OpBase, *args: TensorWrapperBase): | def _(op: OpBase, *args: TensorWrapperBase): | ||||
| assert args | assert args | ||||
| Wrapper = type(args[0]) | Wrapper = type(args[0]) | ||||
| @@ -102,7 +102,7 @@ class Function: | |||||
| Function.apply = Function.__call__ | Function.apply = Function.__call__ | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: Function, *args: TensorWrapperBase): | def _(op: Function, *args: TensorWrapperBase): | ||||
| assert args | assert args | ||||
| Wrapper = type(args[0]) | Wrapper = type(args[0]) | ||||
| @@ -148,11 +148,11 @@ def _(op: Function, *args: TensorWrapperBase): | |||||
| return tuple(map(Wrapper, outputs)) | return tuple(map(Wrapper, outputs)) | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: Function, *args: Tensor): | def _(op: Function, *args: Tensor): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: Function, *args: RawTensor): | def _(op: Function, *args: RawTensor): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -111,7 +111,7 @@ def _unwrap(x): | |||||
| return x._node | return x._node | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: OpDef, *args: VarNode): | def _(op: OpDef, *args: VarNode): | ||||
| outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | ||||
| return _wrap(outputs) | return _wrap(outputs) | ||||
| @@ -0,0 +1,10 @@ | |||||
| # This directory is a fork of multipledispatch. | |||||
| # | |||||
| # Repo: https://github.com/mrocklin/multipledispatch | |||||
| # Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834 | |||||
| # Authors: Matthew Rocklin et al. | |||||
| # | |||||
| # Refer to ACKNOWLEDGEMENT for copyright and liscense information | |||||
| from .core import dispatch | |||||
| from .dispatcher import Dispatcher | |||||
| @@ -0,0 +1,121 @@ | |||||
| from .utils import _toposort, groupby | |||||
| from .variadic import isvariadic | |||||
| class AmbiguityWarning(Warning): | |||||
| pass | |||||
| def supercedes(a, b): | |||||
| """ A is consistent and strictly more specific than B """ | |||||
| if len(a) < len(b): | |||||
| # only case is if a is empty and b is variadic | |||||
| return not a and len(b) == 1 and isvariadic(b[-1]) | |||||
| elif len(a) == len(b): | |||||
| return all(map(issubclass, a, b)) | |||||
| else: | |||||
| # len(a) > len(b) | |||||
| p1 = 0 | |||||
| p2 = 0 | |||||
| while p1 < len(a) and p2 < len(b): | |||||
| cur_a = a[p1] | |||||
| cur_b = b[p2] | |||||
| if not (isvariadic(cur_a) or isvariadic(cur_b)): | |||||
| if not issubclass(cur_a, cur_b): | |||||
| return False | |||||
| p1 += 1 | |||||
| p2 += 1 | |||||
| elif isvariadic(cur_a): | |||||
| assert p1 == len(a) - 1 | |||||
| return p2 == len(b) - 1 and issubclass(cur_a, cur_b) | |||||
| elif isvariadic(cur_b): | |||||
| assert p2 == len(b) - 1 | |||||
| if not issubclass(cur_a, cur_b): | |||||
| return False | |||||
| p1 += 1 | |||||
| return p2 == len(b) - 1 and p1 == len(a) | |||||
| def consistent(a, b): | |||||
| """ It is possible for an argument list to satisfy both A and B """ | |||||
| # Need to check for empty args | |||||
| if not a: | |||||
| return not b or isvariadic(b[0]) | |||||
| if not b: | |||||
| return not a or isvariadic(a[0]) | |||||
| # Non-empty args check for mutual subclasses | |||||
| if len(a) == len(b): | |||||
| return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) | |||||
| else: | |||||
| p1 = 0 | |||||
| p2 = 0 | |||||
| while p1 < len(a) and p2 < len(b): | |||||
| cur_a = a[p1] | |||||
| cur_b = b[p2] | |||||
| if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): | |||||
| return False | |||||
| if not (isvariadic(cur_a) or isvariadic(cur_b)): | |||||
| p1 += 1 | |||||
| p2 += 1 | |||||
| elif isvariadic(cur_a): | |||||
| p2 += 1 | |||||
| elif isvariadic(cur_b): | |||||
| p1 += 1 | |||||
| # We only need to check for variadic ends | |||||
| # Variadic types are guaranteed to be the last element | |||||
| return isvariadic(cur_a) and p2 == len(b) or isvariadic(cur_b) and p1 == len(a) | |||||
| def ambiguous(a, b): | |||||
| """ A is consistent with B but neither is strictly more specific """ | |||||
| return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) | |||||
| def ambiguities(signatures): | |||||
| """ All signature pairs such that A is ambiguous with B """ | |||||
| signatures = list(map(tuple, signatures)) | |||||
| return set( | |||||
| (a, b) | |||||
| for a in signatures | |||||
| for b in signatures | |||||
| if hash(a) < hash(b) | |||||
| and ambiguous(a, b) | |||||
| and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) | |||||
| ) | |||||
| def super_signature(signatures): | |||||
| """ A signature that would break ambiguities """ | |||||
| n = len(signatures[0]) | |||||
| assert all(len(s) == n for s in signatures) | |||||
| return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] for i in range(n)] | |||||
| def edge(a, b, tie_breaker=hash): | |||||
| """ A should be checked before B | |||||
| Tie broken by tie_breaker, defaults to ``hash`` | |||||
| """ | |||||
| # A either supercedes B and B does not supercede A or if B does then call | |||||
| # tie_breaker | |||||
| return supercedes(a, b) and ( | |||||
| not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) | |||||
| ) | |||||
| def ordering(signatures): | |||||
| """ A sane ordering of signatures to check, first to last | |||||
| Topoological sort of edges as given by ``edge`` and ``supercedes`` | |||||
| """ | |||||
| signatures = list(map(tuple, signatures)) | |||||
| edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] | |||||
| edges = groupby(lambda x: x[0], edges) | |||||
| for s in signatures: | |||||
| if s not in edges: | |||||
| edges[s] = [] | |||||
| edges = dict((k, [b for a, b in v]) for k, v in edges.items()) | |||||
| return _toposort(edges) | |||||
| @@ -0,0 +1,88 @@ | |||||
| import inspect | |||||
| import sys | |||||
| from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn | |||||
| global_namespace = dict() | |||||
| def dispatch(*types, **kwargs): | |||||
| """ Dispatch function on the types of the inputs | |||||
| Supports dispatch on all non-keyword arguments. | |||||
| Collects implementations based on the function name. Ignores namespaces. | |||||
| If ambiguous type signatures occur a warning is raised when the function is | |||||
| defined suggesting the additional method to break the ambiguity. | |||||
| Examples | |||||
| -------- | |||||
| >>> @dispatch(int) | |||||
| ... def f(x): | |||||
| ... return x + 1 | |||||
| >>> @dispatch(float) | |||||
| ... def f(x): | |||||
| ... return x - 1 | |||||
| >>> f(3) | |||||
| 4 | |||||
| >>> f(3.0) | |||||
| 2.0 | |||||
| Specify an isolated namespace with the namespace keyword argument | |||||
| >>> my_namespace = dict() | |||||
| >>> @dispatch(int, namespace=my_namespace) | |||||
| ... def foo(x): | |||||
| ... return x + 1 | |||||
| Dispatch on instance methods within classes | |||||
| >>> class MyClass(object): | |||||
| ... @dispatch(list) | |||||
| ... def __init__(self, data): | |||||
| ... self.data = data | |||||
| ... @dispatch(int) | |||||
| ... def __init__(self, datum): | |||||
| ... self.data = [datum] | |||||
| """ | |||||
| namespace = kwargs.get("namespace", global_namespace) | |||||
| types = tuple(types) | |||||
| def _df(func): | |||||
| name = func.__name__ | |||||
| if ismethod(func): | |||||
| dispatcher = inspect.currentframe().f_back.f_locals.get( | |||||
| name, MethodDispatcher(name), | |||||
| ) | |||||
| else: | |||||
| if name not in namespace: | |||||
| namespace[name] = Dispatcher(name) | |||||
| dispatcher = namespace[name] | |||||
| dispatcher.add(types, func) | |||||
| return dispatcher | |||||
| return _df | |||||
| def ismethod(func): | |||||
| """ Is func a method? | |||||
| Note that this has to work as the method is defined but before the class is | |||||
| defined. At this stage methods look like functions. | |||||
| """ | |||||
| if hasattr(inspect, "signature"): | |||||
| signature = inspect.signature(func) | |||||
| return signature.parameters.get("self", None) is not None | |||||
| else: | |||||
| if sys.version_info.major < 3: | |||||
| spec = inspect.getargspec(func) | |||||
| else: | |||||
| spec = inspect.getfullargspec(func) | |||||
| return spec and spec.args and spec.args[0] == "self" | |||||
| @@ -0,0 +1,401 @@ | |||||
| import copy | |||||
| import inspect | |||||
| import itertools as itl | |||||
| from warnings import warn | |||||
| from ..._imperative_rt.dispatcher import Dispatcher as CDispatcher | |||||
| from .conflict import AmbiguityWarning, ambiguities, ordering, super_signature | |||||
| from .utils import expand_tuples, parse_union | |||||
| from .variadic import Variadic, isvariadic | |||||
| def ambiguity_warn(dispatcher, ambiguities): | |||||
| """ Raise warning when ambiguity is detected | |||||
| Parameters | |||||
| ---------- | |||||
| dispatcher : Dispatcher | |||||
| The dispatcher on which the ambiguity was detected | |||||
| ambiguities : set | |||||
| Set of type signature pairs that are ambiguous within this dispatcher | |||||
| See Also: | |||||
| Dispatcher.add | |||||
| warning_text | |||||
| """ | |||||
| warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) | |||||
| def variadic_signature_matches_iter(types, full_signature): | |||||
| """Check if a set of input types matches a variadic signature. | |||||
| Notes | |||||
| ----- | |||||
| The algorithm is as follows: | |||||
| Initialize the current signature to the first in the sequence | |||||
| For each type in `types`: | |||||
| If the current signature is variadic | |||||
| If the type matches the signature | |||||
| yield True | |||||
| Else | |||||
| Try to get the next signature | |||||
| If no signatures are left we can't possibly have a match | |||||
| so yield False | |||||
| Else | |||||
| yield True if the type matches the current signature | |||||
| Get the next signature | |||||
| """ | |||||
| sigiter = iter(full_signature) | |||||
| sig = next(sigiter) | |||||
| for typ in types: | |||||
| matches = issubclass(typ, sig) | |||||
| yield matches | |||||
| if not isvariadic(sig): | |||||
| # we're not matching a variadic argument, so move to the next | |||||
| # element in the signature | |||||
| sig = next(sigiter) | |||||
| else: | |||||
| try: | |||||
| sig = next(sigiter) | |||||
| except StopIteration: | |||||
| assert isvariadic(sig) | |||||
| yield True | |||||
| else: | |||||
| # We have signature items left over, so all of our arguments | |||||
| # haven't matched | |||||
| yield False | |||||
| def variadic_signature_matches(types, full_signature): | |||||
| # No arguments always matches a variadic signature | |||||
| assert full_signature | |||||
| return all(variadic_signature_matches_iter(types, full_signature)) | |||||
| def get_func_signature(function): | |||||
| sig = inspect.signature(function) | |||||
| types = [] | |||||
| for p in sig.parameters.values(): | |||||
| ann = p.annotation | |||||
| ann = parse_union(ann) or ann | |||||
| if p.kind in ( | |||||
| inspect.Parameter.POSITIONAL_ONLY, | |||||
| inspect.Parameter.POSITIONAL_OR_KEYWORD, | |||||
| ): | |||||
| types.append(ann) | |||||
| if p.kind == inspect.Parameter.VAR_POSITIONAL: | |||||
| types.append([ann]) | |||||
| return tuple(types) | |||||
| class Frame: | |||||
| __slots__ = "args", "types", "mro", "mro_offset" | |||||
| class Dispatcher(CDispatcher): | |||||
| """ Dispatch methods based on type signature | |||||
| Use ``dispatch`` to add implementations | |||||
| Examples | |||||
| -------- | |||||
| >>> from multipledispatch import dispatch | |||||
| >>> @dispatch(int) | |||||
| ... def f(x): | |||||
| ... return x + 1 | |||||
| >>> @dispatch(float) | |||||
| ... def f(x): | |||||
| ... return x - 1 | |||||
| >>> f(3) | |||||
| 4 | |||||
| >>> f(3.0) | |||||
| 2.0 | |||||
| """ | |||||
| __slots__ = "__name__", "name", "funcs", "_ordering", "doc" | |||||
| def __init__(self, name, doc=None): | |||||
| self.name = self.__name__ = name | |||||
| self.funcs = {} | |||||
| self.doc = doc | |||||
| def register(self, *types, **kwargs): | |||||
| """ register dispatcher with new implementation | |||||
| >>> f = Dispatcher('f') | |||||
| >>> @f.register(int) | |||||
| ... def inc(x): | |||||
| ... return x + 1 | |||||
| >>> @f.register(float) | |||||
| ... def dec(x): | |||||
| ... return x - 1 | |||||
| >>> @f.register(list) | |||||
| ... @f.register(tuple) | |||||
| ... def reverse(x): | |||||
| ... return x[::-1] | |||||
| >>> f(1) | |||||
| 2 | |||||
| >>> f(1.0) | |||||
| 0.0 | |||||
| >>> f([1, 2, 3]) | |||||
| [3, 2, 1] | |||||
| """ | |||||
| def _df(func): | |||||
| self.add(types, func, **kwargs) | |||||
| return func | |||||
| return _df | |||||
| def add(self, signature, func): | |||||
| """ Add new types/method pair to dispatcher | |||||
| >>> D = Dispatcher('add') | |||||
| >>> D.add((int, int), lambda x, y: x + y) | |||||
| >>> D.add((float, float), lambda x, y: x + y) | |||||
| >>> D(1, 2) | |||||
| 3 | |||||
| >>> D(1, 2.0) | |||||
| Traceback (most recent call last): | |||||
| ... | |||||
| NotImplementedError: Could not find signature for add: <int, float> | |||||
| When ``add`` detects a warning it calls the ``on_ambiguity`` callback | |||||
| with a dispatcher/itself, and a set of ambiguous type signature pairs | |||||
| as inputs. See ``ambiguity_warn`` for an example. | |||||
| """ | |||||
| # Handle annotations | |||||
| if not signature: | |||||
| signature = get_func_signature(func) | |||||
| # Handle union types | |||||
| if any(isinstance(typ, tuple) for typ in signature): | |||||
| for typs in expand_tuples(signature): | |||||
| self.add(typs, func) | |||||
| return | |||||
| new_signature = [] | |||||
| for index, typ in enumerate(signature, start=1): | |||||
| if not isinstance(typ, (type, list)): | |||||
| str_sig = ", ".join( | |||||
| c.__name__ if isinstance(c, type) else str(c) for c in signature | |||||
| ) | |||||
| raise TypeError( | |||||
| "Tried to dispatch on non-type: %s\n" | |||||
| "In signature: <%s>\n" | |||||
| "In function: %s" % (typ, str_sig, self.name) | |||||
| ) | |||||
| # handle variadic signatures | |||||
| if isinstance(typ, list): | |||||
| if index != len(signature): | |||||
| raise TypeError("Variadic signature must be the last element") | |||||
| if len(typ) != 1: | |||||
| raise TypeError( | |||||
| "Variadic signature must contain exactly one element. " | |||||
| "To use a variadic union type place the desired types " | |||||
| "inside of a tuple, e.g., [(int, str)]" | |||||
| ) | |||||
| new_signature.append(Variadic[typ[0]]) | |||||
| else: | |||||
| new_signature.append(typ) | |||||
| l = self.funcs.setdefault(tuple(new_signature), []) | |||||
| for i in l: | |||||
| if i is func: | |||||
| raise ValueError("already registered") | |||||
| l.append(func) | |||||
| self.enable(func) | |||||
| self.clear_cache() | |||||
| try: | |||||
| del self._ordering | |||||
| except AttributeError: | |||||
| pass | |||||
| @property | |||||
| def ordering(self): | |||||
| try: | |||||
| return self._ordering | |||||
| except AttributeError: | |||||
| return self.reorder() | |||||
| def reorder(self, on_ambiguity=ambiguity_warn): | |||||
| self._ordering = od = ordering(self.funcs) | |||||
| amb = ambiguities(self.funcs) | |||||
| if amb: | |||||
| on_ambiguity(self, amb) | |||||
| return od | |||||
| def __str__(self): | |||||
| return "<dispatched %s>" % self.name | |||||
| __repr__ = __str__ | |||||
| def dispatch(self, *types): | |||||
| """Deterimine appropriate implementation for this type signature | |||||
| This method is internal. Users should call this object as a function. | |||||
| Implementation resolution occurs within the ``__call__`` method. | |||||
| >>> from multipledispatch import dispatch | |||||
| >>> @dispatch(int) | |||||
| ... def inc(x): | |||||
| ... return x + 1 | |||||
| >>> implementation = inc.dispatch(int) | |||||
| >>> implementation(3) | |||||
| 4 | |||||
| >>> print(inc.dispatch(float)) | |||||
| None | |||||
| See Also: | |||||
| ``multipledispatch.conflict`` - module to determine resolution order | |||||
| """ | |||||
| if types in self.funcs: | |||||
| return self.funcs[types][-1] | |||||
| for f in self.dispatch_iter(*types): | |||||
| return f | |||||
| def dispatch_iter(self, *types): | |||||
| n = len(types) | |||||
| for signature in self.ordering: | |||||
| if ( | |||||
| len(signature) == n | |||||
| and all(map(issubclass, types, signature)) | |||||
| or len(signature) | |||||
| and isvariadic(signature[-1]) | |||||
| and variadic_signature_matches(types, signature) | |||||
| ): | |||||
| yield from self.funcs[signature][::-1] | |||||
| def __getstate__(self): | |||||
| return {"name": self.name, "funcs": self.funcs} | |||||
| def __setstate__(self, d): | |||||
| self.name = d["name"] | |||||
| self.funcs = d["funcs"] | |||||
| self._ordering = ordering(self.funcs) | |||||
| self._cache = dict() | |||||
| @property | |||||
| def __doc__(self): | |||||
| docs = ["Multiply dispatched method: %s" % self.name] | |||||
| if self.doc: | |||||
| docs.append(self.doc) | |||||
| other = [] | |||||
| for sig in self.ordering[::-1]: | |||||
| funcs = self.funcs[sig] | |||||
| s = "Inputs: <%s>\n" % str_signature(sig) | |||||
| sep = "-" * len(s) + "\n" | |||||
| for i, func in enumerate(funcs): | |||||
| s += sep | |||||
| if len(funcs) > 1: | |||||
| s += "[Handler %d]\n\n" % (i + 1) | |||||
| if i: | |||||
| s += "\n\n" | |||||
| if func.__doc__: | |||||
| s += func.__doc__.strip() | |||||
| else: | |||||
| s += repr(func) + "\n" | |||||
| docs.append(s) | |||||
| return "\n\n".join(docs) | |||||
| def _help(self, *args): | |||||
| return self.dispatch(*map(type, args)).__doc__ | |||||
| def help(self, *args, **kwargs): | |||||
| """ Print docstring for the function corresponding to inputs """ | |||||
| print(self._help(*args)) | |||||
| def _source(self, *args): | |||||
| func = self.dispatch(*map(type, args)) | |||||
| if not func: | |||||
| raise TypeError("No function found") | |||||
| return source(func) | |||||
| def source(self, *args, **kwargs): | |||||
| """ Print source code for the function corresponding to inputs """ | |||||
| print(self._source(*args)) | |||||
| def source(func): | |||||
| s = "File: %s\n\n" % inspect.getsourcefile(func) | |||||
| s = s + inspect.getsource(func) | |||||
| return s | |||||
| class MethodDispatcher(Dispatcher): | |||||
| """ Dispatch methods based on type signature | |||||
| See Also: | |||||
| Dispatcher | |||||
| """ | |||||
| __slots__ = ("obj", "cls") | |||||
| @classmethod | |||||
| def get_func_params(cls, func): | |||||
| if hasattr(inspect, "signature"): | |||||
| sig = inspect.signature(func) | |||||
| return itl.islice(sig.parameters.values(), 1, None) | |||||
| def __get__(self, instance, owner): | |||||
| self.obj = instance | |||||
| self.cls = owner | |||||
| return self | |||||
| def __call__(self, *args, **kwargs): | |||||
| types = tuple([type(arg) for arg in args]) | |||||
| func = self.dispatch(*types) | |||||
| if not func: | |||||
| raise NotImplementedError( | |||||
| "Could not find signature for %s: <%s>" | |||||
| % (self.name, str_signature(types)) | |||||
| ) | |||||
| return func(self.obj, *args, **kwargs) | |||||
| def str_signature(sig): | |||||
| """ String representation of type signature | |||||
| >>> str_signature((int, float)) | |||||
| 'int, float' | |||||
| """ | |||||
| return ", ".join(cls.__name__ for cls in sig) | |||||
| def warning_text(name, amb): | |||||
| """ The text for ambiguity warnings """ | |||||
| text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) | |||||
| text += "The following signatures may result in ambiguous behavior:\n" | |||||
| for pair in amb: | |||||
| text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" | |||||
| text += "\n\nConsider making the following additions:\n\n" | |||||
| text += "\n\n".join( | |||||
| [ | |||||
| "@dispatch(" + str_signature(super_signature(s)) + ")\ndef %s(...)" % name | |||||
| for s in amb | |||||
| ] | |||||
| ) | |||||
| return text | |||||
| @@ -0,0 +1,177 @@ | |||||
| import sys | |||||
| import typing | |||||
| from collections import OrderedDict | |||||
| def raises(err, lamda): | |||||
| try: | |||||
| lamda() | |||||
| return False | |||||
| except err: | |||||
| return True | |||||
| def expand_tuples(L): | |||||
| """ | |||||
| >>> expand_tuples([1, (2, 3)]) | |||||
| [(1, 2), (1, 3)] | |||||
| >>> expand_tuples([1, 2]) | |||||
| [(1, 2)] | |||||
| """ | |||||
| if not L: | |||||
| return [()] | |||||
| elif not isinstance(L[0], tuple): | |||||
| rest = expand_tuples(L[1:]) | |||||
| return [(L[0],) + t for t in rest] | |||||
| else: | |||||
| rest = expand_tuples(L[1:]) | |||||
| return [(item,) + t for t in rest for item in L[0]] | |||||
| # Taken from theano/theano/gof/sched.py | |||||
| # Avoids licensing issues because this was written by Matthew Rocklin | |||||
| def _toposort(edges): | |||||
| """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) | |||||
| inputs: | |||||
| edges - a dict of the form {a: {b, c}} where b and c depend on a | |||||
| outputs: | |||||
| L - an ordered list of nodes that satisfy the dependencies of edges | |||||
| >>> _toposort({1: (2, 3), 2: (3, )}) | |||||
| [1, 2, 3] | |||||
| Closely follows the wikipedia page [2] | |||||
| [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", | |||||
| Communications of the ACM | |||||
| [2] http://en.wikipedia.org/wiki/Toposort#Algorithms | |||||
| """ | |||||
| incoming_edges = reverse_dict(edges) | |||||
| incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) | |||||
| S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) | |||||
| L = [] | |||||
| while S: | |||||
| n, _ = S.popitem() | |||||
| L.append(n) | |||||
| for m in edges.get(n, ()): | |||||
| assert n in incoming_edges[m] | |||||
| incoming_edges[m].remove(n) | |||||
| if not incoming_edges[m]: | |||||
| S[m] = None | |||||
| if any(incoming_edges.get(v, None) for v in edges): | |||||
| raise ValueError("Input has cycles") | |||||
| return L | |||||
| def reverse_dict(d): | |||||
| """Reverses direction of dependence dict | |||||
| >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} | |||||
| >>> reverse_dict(d) # doctest: +SKIP | |||||
| {1: ('a',), 2: ('a', 'b'), 3: ('b',)} | |||||
| :note: dict order are not deterministic. As we iterate on the | |||||
| input dict, it make the output of this function depend on the | |||||
| dict order. So this function output order should be considered | |||||
| as undeterministic. | |||||
| """ | |||||
| result = OrderedDict() | |||||
| for key in d: | |||||
| for val in d[key]: | |||||
| result[val] = result.get(val, tuple()) + (key,) | |||||
| return result | |||||
| # Taken from toolz | |||||
| # Avoids licensing issues because this version was authored by Matthew Rocklin | |||||
| def groupby(func, seq): | |||||
| """ Group a collection by a key function | |||||
| >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] | |||||
| >>> groupby(len, names) # doctest: +SKIP | |||||
| {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} | |||||
| >>> iseven = lambda x: x % 2 == 0 | |||||
| >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP | |||||
| {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} | |||||
| See Also: | |||||
| ``countby`` | |||||
| """ | |||||
| d = OrderedDict() | |||||
| for item in seq: | |||||
| key = func(item) | |||||
| if key not in d: | |||||
| d[key] = list() | |||||
| d[key].append(item) | |||||
| return d | |||||
| def typename(type): | |||||
| """Get the name of `type`. | |||||
| Parameters | |||||
| ---------- | |||||
| type : Union[Type, Tuple[Type]] | |||||
| Returns | |||||
| ------- | |||||
| str | |||||
| The name of `type` or a tuple of the names of the types in `type`. | |||||
| Examples | |||||
| -------- | |||||
| >>> typename(int) | |||||
| 'int' | |||||
| >>> typename((int, float)) | |||||
| '(int, float)' | |||||
| """ | |||||
| try: | |||||
| return type.__name__ | |||||
| except AttributeError: | |||||
| if len(type) == 1: | |||||
| return typename(*type) | |||||
| return "(%s)" % ", ".join(map(typename, type)) | |||||
| # parse typing.Union | |||||
| if sys.version_info < (3, 6): | |||||
| def parse_union(ann): | |||||
| if type(ann) is not typing.UnionMeta: | |||||
| return | |||||
| return ann.__union_params__ | |||||
| elif sys.version_info < (3, 7): | |||||
| def parse_union(ann): | |||||
| if type(ann) is not typing._Union: | |||||
| return | |||||
| return ann.__args__ | |||||
| elif sys.version_info < (3, 8): | |||||
| def parse_union(ann): | |||||
| if type(ann) is not typing._GenericAlias: | |||||
| if type(ann) is not typing.Union: | |||||
| return | |||||
| else: | |||||
| if ann.__origin__ is not typing.Union: | |||||
| return | |||||
| return ann.__args__ | |||||
| else: | |||||
| def parse_union(ann): | |||||
| if typing.get_origin(ann) is not typing.Union: | |||||
| return | |||||
| return typing.get_args(ann) | |||||
| @@ -0,0 +1,95 @@ | |||||
| from .utils import typename | |||||
| class VariadicSignatureType(type): | |||||
| # checking if subclass is a subclass of self | |||||
| def __subclasscheck__(self, subclass): | |||||
| other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) | |||||
| return subclass is self or all( | |||||
| issubclass(other, self.variadic_type) for other in other_type | |||||
| ) | |||||
| def __eq__(self, other): | |||||
| """ | |||||
| Return True if other has the same variadic type | |||||
| Parameters | |||||
| ---------- | |||||
| other : object (type) | |||||
| The object (type) to check | |||||
| Returns | |||||
| ------- | |||||
| bool | |||||
| Whether or not `other` is equal to `self` | |||||
| """ | |||||
| return isvariadic(other) and set(self.variadic_type) == set(other.variadic_type) | |||||
| def __hash__(self): | |||||
| return hash((type(self), frozenset(self.variadic_type))) | |||||
| def isvariadic(obj): | |||||
| """Check whether the type `obj` is variadic. | |||||
| Parameters | |||||
| ---------- | |||||
| obj : type | |||||
| The type to check | |||||
| Returns | |||||
| ------- | |||||
| bool | |||||
| Whether or not `obj` is variadic | |||||
| Examples | |||||
| -------- | |||||
| >>> isvariadic(int) | |||||
| False | |||||
| >>> isvariadic(Variadic[int]) | |||||
| True | |||||
| """ | |||||
| return isinstance(obj, VariadicSignatureType) | |||||
| class VariadicSignatureMeta(type): | |||||
| """A metaclass that overrides ``__getitem__`` on the class. This is used to | |||||
| generate a new type for Variadic signatures. See the Variadic class for | |||||
| examples of how this behaves. | |||||
| """ | |||||
| def __getitem__(self, variadic_type): | |||||
| if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): | |||||
| raise ValueError( | |||||
| "Variadic types must be type or tuple of types" | |||||
| " (Variadic[int] or Variadic[(int, float)]" | |||||
| ) | |||||
| if not isinstance(variadic_type, tuple): | |||||
| variadic_type = (variadic_type,) | |||||
| return VariadicSignatureType( | |||||
| "Variadic[%s]" % typename(variadic_type), | |||||
| (), | |||||
| dict(variadic_type=variadic_type, __slots__=()), | |||||
| ) | |||||
| class Variadic(metaclass=VariadicSignatureMeta): | |||||
| """A class whose getitem method can be used to generate a new type | |||||
| representing a specific variadic signature. | |||||
| Examples | |||||
| -------- | |||||
| >>> Variadic[int] # any number of int arguments | |||||
| <class 'multipledispatch.variadic.Variadic[int]'> | |||||
| >>> Variadic[(int, str)] # any number of one of int or str arguments | |||||
| <class 'multipledispatch.variadic.Variadic[(int, str)]'> | |||||
| >>> issubclass(int, Variadic[int]) | |||||
| True | |||||
| >>> issubclass(int, Variadic[(int, str)]) | |||||
| True | |||||
| >>> issubclass(str, Variadic[(int, str)]) | |||||
| True | |||||
| >>> issubclass(float, Variadic[(int, str)]) | |||||
| False | |||||
| """ | |||||
| @@ -66,13 +66,13 @@ class RawTensor(TensorBase): | |||||
| delete(self._handle) | delete(self._handle) | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: OpDef, *args: RawTensor): | def _(op: OpDef, *args: RawTensor): | ||||
| outputs = apply_op(op, tuple(i._handle for i in args)) | outputs = apply_op(op, tuple(i._handle for i in args)) | ||||
| return tuple(map(RawTensor, outputs)) | return tuple(map(RawTensor, outputs)) | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: Const, *args: RawTensor): | def _(op: Const, *args: RawTensor): | ||||
| dtype = op.dtype | dtype = op.dtype | ||||
| device = as_device(op.device).to_c() | device = as_device(op.device).to_c() | ||||
| @@ -79,7 +79,7 @@ def get_context(): | |||||
| return _context | return _context | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def tensor_apply(op: OpBase, *args: Tensor): | def tensor_apply(op: OpBase, *args: Tensor): | ||||
| data = tuple(i._data if isinstance(i, Tensor) else i for i in args) | data = tuple(i._data if isinstance(i, Tensor) else i for i in args) | ||||
| # type(Tensor._data) is RawTensor | # type(Tensor._data) is RawTensor | ||||
| @@ -46,7 +46,7 @@ __all__ = [ | |||||
| ] | ] | ||||
| @apply.add | |||||
| @apply.register() | |||||
| def _(op: RemoteSend, *args: Tensor): | def _(op: RemoteSend, *args: Tensor): | ||||
| ret = tensor_apply(op, *args) | ret = tensor_apply(op, *args) | ||||
| @@ -1,5 +1,4 @@ | |||||
| numpy>=1.18 | numpy>=1.18 | ||||
| multipledispatch==0.6.0 | |||||
| opencv-python | opencv-python | ||||
| pyarrow | pyarrow | ||||
| requests | requests | ||||
| @@ -0,0 +1,180 @@ | |||||
| #include "./dispatcher.h" | |||||
| #include "./pyext17.h" | |||||
| #include "megbrain/utils/hash.h" | |||||
| #include "megbrain/utils/small_vector.h" | |||||
| #include <unordered_map> | |||||
| #include <structmember.h> | |||||
| namespace py = pybind11; | |||||
| namespace pyx = pyext17; | |||||
| namespace { | |||||
| struct Handler { | |||||
| PyObject* func; // borrowed | |||||
| bool enabled; | |||||
| Handler() = default; | |||||
| Handler(PyObject* func_, bool enable = true) : func(func_), enabled(enable) {} | |||||
| }; | |||||
| using FastSig = mgb::SmallVector<void*, 8>; | |||||
| using MRO = std::vector<Handler*>; | |||||
| struct Frame { | |||||
| MRO* mro; | |||||
| size_t mro_offset; | |||||
| Frame() = default; | |||||
| Frame(MRO* mro_, size_t mro_offset_ = 0) : mro(mro_), mro_offset(mro_offset_) {} | |||||
| }; | |||||
| struct FastSigHash { | |||||
| size_t operator()(const FastSig& sig) const { | |||||
| auto* ptr = &sig.front(); | |||||
| return mgb::XXHash() | |||||
| .update(ptr, sig.size() * sizeof(FastSig::value_type)) | |||||
| .digest(); | |||||
| } | |||||
| }; | |||||
| struct ObjectIdHash : std::hash<void*> { | |||||
| size_t operator()(const py::handle& h) const { | |||||
| return std::hash<void*>::operator()(h.ptr()); | |||||
| } | |||||
| }; | |||||
| struct Dispatcher { | |||||
| std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache; | |||||
| std::vector<Frame> stack; | |||||
| std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry; | |||||
| inline py::handle self() { | |||||
| return pyx::wrap<Dispatcher>::pycast(this); | |||||
| } | |||||
| bool prepare_call(PyObject*const* args, Py_ssize_t nargs) { | |||||
| FastSig sig(nargs); | |||||
| for (Py_ssize_t i = 0; i < nargs; ++i) { | |||||
| sig[i] = Py_TYPE(args[i]); | |||||
| } | |||||
| auto it = cache.find(sig); | |||||
| if (it == cache.end()) { | |||||
| if (auto mro = resolve(sig)) { | |||||
| it = cache.emplace(std::move(sig), std::move(mro)).first; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| stack.emplace_back(it->second.get()); | |||||
| return true; | |||||
| } | |||||
| template<typename T> | |||||
| PyObject* do_call(T&& caller) { | |||||
| auto& frame = stack.back(); | |||||
| auto& mro = *frame.mro; | |||||
| auto& i = frame.mro_offset; | |||||
| for (; i < mro.size(); ++i) { | |||||
| if (mro[i]->enabled) { | |||||
| auto ret = caller(mro[i]->func); | |||||
| if (ret != Py_NotImplemented) { | |||||
| stack.pop_back(); | |||||
| return ret; | |||||
| } | |||||
| Py_DECREF(ret); | |||||
| } | |||||
| } | |||||
| PyErr_SetString(PyExc_NotImplementedError, "mro exhausted"); | |||||
| stack.pop_back(); | |||||
| return nullptr; | |||||
| } | |||||
| std::unique_ptr<MRO> resolve(const FastSig& sig) { | |||||
| try { | |||||
| py::tuple args(sig.size()); | |||||
| for (size_t i = 0; i < sig.size(); ++i) { | |||||
| args[i] = (PyObject*)sig[i]; | |||||
| } | |||||
| auto mro_iter = self().attr("dispatch_iter")(*args); | |||||
| auto ret = std::make_unique<MRO>(); | |||||
| for (auto i : mro_iter) { | |||||
| auto it = registry.find(py::reinterpret_borrow<py::object>(i)); | |||||
| if (it == registry.end()) { | |||||
| PyErr_SetString(PyExc_RuntimeError, "resolved to unregistered function"); | |||||
| return nullptr; | |||||
| } | |||||
| ret->push_back(it->second.get()); | |||||
| } | |||||
| return ret; | |||||
| } catch (py::error_already_set& e) { | |||||
| e.restore(); | |||||
| } catch (std::runtime_error& e) { | |||||
| PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| public: | |||||
| static constexpr auto tp_name = "Dispatcher"; | |||||
| PyObject* tp_vectorcall(PyObject*const* args, Py_ssize_t nargs) { | |||||
| if (!prepare_call(args, nargs)) return nullptr; | |||||
| return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);}); | |||||
| } | |||||
| PyObject* tp_call(PyObject* args, PyObject* kwargs) { | |||||
| if (!prepare_call(&PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args))) return nullptr; | |||||
| return do_call([=](PyObject* func){return PyObject_Call(func, args, kwargs);}); | |||||
| } | |||||
| PyObject* super(PyObject*const* args, Py_ssize_t nargs) { | |||||
| if (stack.empty()) { | |||||
| PyErr_SetString(PyExc_RuntimeError, "super called at top level"); | |||||
| return nullptr; | |||||
| } | |||||
| stack.emplace_back(stack.back()).mro_offset++; | |||||
| return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);}); | |||||
| } | |||||
| void enable(PyObject* func) { | |||||
| auto obj = py::reinterpret_borrow<py::object>(func); | |||||
| auto it = registry.find(obj); | |||||
| if (it != registry.end()) { | |||||
| it->second->enabled = true; | |||||
| } else { | |||||
| registry.emplace(std::move(obj), std::make_unique<Handler>(func)); | |||||
| } | |||||
| } | |||||
| PyObject* disable(PyObject* func) { | |||||
| auto obj = py::reinterpret_borrow<py::object>(func); | |||||
| auto it = registry.find(obj); | |||||
| if (it == registry.end()) { | |||||
| PyErr_SetString(PyExc_ValueError, "function not registered"); | |||||
| return nullptr; | |||||
| } else { | |||||
| it->second->enabled = false; | |||||
| } | |||||
| Py_RETURN_NONE; | |||||
| } | |||||
| void clear_cache() { | |||||
| cache.clear(); | |||||
| } | |||||
| }; | |||||
| } // namespace | |||||
| void init_dispatcher(py::module m) { | |||||
| auto* dispatcher_type = pyx::wrap<Dispatcher>::type() | |||||
| .def<&Dispatcher::enable>("enable") | |||||
| .def<&Dispatcher::disable>("disable") | |||||
| .def<&Dispatcher::clear_cache>("clear_cache") | |||||
| .def<&Dispatcher::tp_vectorcall>("call") | |||||
| .def<&Dispatcher::super>("super") | |||||
| .finalize(); | |||||
| if (!dispatcher_type) throw py::error_already_set(); | |||||
| m.attr("Dispatcher") = dispatcher_type; | |||||
| } | |||||
| @@ -0,0 +1,5 @@ | |||||
| #pragma once | |||||
| #include <pybind11/pybind11.h> | |||||
| void init_dispatcher(pybind11::module); | |||||
| @@ -21,6 +21,8 @@ | |||||
| #include "./graph_rt.h" | #include "./graph_rt.h" | ||||
| #include "./ops.h" | #include "./ops.h" | ||||
| #include "./dispatcher.h" | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| #ifndef MODULE_NAME | #ifndef MODULE_NAME | ||||
| @@ -63,4 +65,6 @@ PYBIND11_MODULE(MODULE_NAME, m) { | |||||
| from .graph import * | from .graph import * | ||||
| )", | )", | ||||
| py::getattr(m, "__dict__")); | py::getattr(m, "__dict__")); | ||||
| init_dispatcher(submodule(m, "dispatcher")); | |||||
| } | } | ||||
| @@ -0,0 +1,270 @@ | |||||
| #pragma once | |||||
| #include <stdexcept> | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include <Python.h> | |||||
| namespace pyext17 { | |||||
| #ifdef METH_FASTCALL | |||||
| constexpr bool has_fastcall = true; | |||||
| #else | |||||
| constexpr bool has_fastcall = false; | |||||
| #endif | |||||
| template<typename... Args> | |||||
| struct invocable_with { | |||||
| template<typename T> | |||||
| constexpr bool operator()(T&& lmb) { | |||||
| return std::is_invocable_v<T, Args...>; | |||||
| } | |||||
| }; | |||||
| #define HAS_MEMBER_TYPE(T, U) invocable_with<T>{}([](auto&& x) -> typename std::decay_t<decltype(x)>::U {}) | |||||
| #define HAS_MEMBER(T, m) invocable_with<T>{}([](auto&& x) -> decltype(&std::decay_t<decltype(x)>::m) {}) | |||||
| inline PyObject* cvt_retval(PyObject* rv) { | |||||
| return rv; | |||||
| } | |||||
| #define CVT_RET_PYOBJ(...) \ | |||||
| if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \ | |||||
| __VA_ARGS__; \ | |||||
| Py_RETURN_NONE; \ | |||||
| } else { \ | |||||
| return cvt_retval(__VA_ARGS__); \ | |||||
| } | |||||
| template <typename T> | |||||
| struct wrap { | |||||
| private: | |||||
| typedef wrap<T> wrap_t; | |||||
| public: | |||||
| PyObject_HEAD | |||||
| std::aligned_storage_t<sizeof(T), alignof(T)> storage; | |||||
| inline T* inst() { | |||||
| return reinterpret_cast<T*>(&storage); | |||||
| } | |||||
| inline static PyObject* pycast(T* ptr) { | |||||
| return (PyObject*)((char*)ptr - offsetof(wrap_t, storage)); | |||||
| } | |||||
| private: | |||||
| // method wrapper | |||||
| enum struct meth_type { | |||||
| noarg, | |||||
| varkw, | |||||
| fastcall, | |||||
| singarg | |||||
| }; | |||||
| template<auto f> | |||||
| struct detect_meth_type { | |||||
| static constexpr meth_type value = []() { | |||||
| using F = decltype(f); | |||||
| static_assert(std::is_member_function_pointer_v<F>); | |||||
| if constexpr (std::is_invocable_v<F, T>) { | |||||
| return meth_type::noarg; | |||||
| } else if constexpr (std::is_invocable_v<F, T, PyObject*, PyObject*>) { | |||||
| return meth_type::varkw; | |||||
| } else if constexpr (std::is_invocable_v<F, T, PyObject*const*, Py_ssize_t>) { | |||||
| return meth_type::fastcall; | |||||
| } else if constexpr (std::is_invocable_v<F, T, PyObject*>) { | |||||
| return meth_type::singarg; | |||||
| } else { | |||||
| static_assert(!std::is_same_v<F, F>); | |||||
| } | |||||
| }(); | |||||
| }; | |||||
| template<meth_type, auto f> | |||||
| struct meth {}; | |||||
| template<auto f> | |||||
| struct meth<meth_type::noarg, f> { | |||||
| static constexpr int flags = METH_NOARGS; | |||||
| static PyObject* impl(PyObject* self, PyObject*) { | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| CVT_RET_PYOBJ((inst->*f)()); | |||||
| } | |||||
| }; | |||||
| template<auto f> | |||||
| struct meth<meth_type::varkw, f> { | |||||
| static constexpr int flags = METH_VARARGS | METH_KEYWORDS; | |||||
| static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| CVT_RET_PYOBJ((inst->*f)(args, kwargs)); | |||||
| } | |||||
| }; | |||||
| template<auto f> | |||||
| struct meth<meth_type::fastcall, f> { | |||||
| #ifdef METH_FASTCALL | |||||
| static constexpr int flags = METH_FASTCALL; | |||||
| static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) { | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| CVT_RET_PYOBJ((inst->*f)(args, nargs)); | |||||
| } | |||||
| #else | |||||
| static constexpr int flags = METH_VARARGS; | |||||
| static PyObject* impl(PyObject* self, PyObject* args) { | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| auto* arr = &PyTuple_GET_ITEM(args, 0); | |||||
| auto size = PyTuple_GET_SIZE(args); | |||||
| CVT_RET_PYOBJ((inst->*f)(arr, size)); | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| template<auto f> | |||||
| struct meth<meth_type::singarg, f> { | |||||
| static constexpr int flags = METH_O; | |||||
| static PyObject* impl(PyObject* self, PyObject* obj) { | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| CVT_RET_PYOBJ((inst->*f)(obj)); | |||||
| } | |||||
| }; | |||||
| template<auto f> | |||||
| static constexpr PyMethodDef make_meth_def(const char* name, const char* doc = nullptr) { | |||||
| using M = meth<detect_meth_type<f>::value, f>; | |||||
| return {name, (PyCFunction)M::impl, M::flags, doc}; | |||||
| } | |||||
| // polyfills | |||||
| struct tp_new { | |||||
| static constexpr bool provided = HAS_MEMBER(T, tp_new); | |||||
| static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>; | |||||
| static constexpr bool noarg = std::is_default_constructible_v<T>; | |||||
| template<typename = void> | |||||
| static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { | |||||
| auto* self = type->tp_alloc(type, 0); | |||||
| auto* ptr = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| if constexpr (varkw) { | |||||
| new(ptr) T(args, kwargs); | |||||
| } else { | |||||
| new(ptr) T(); | |||||
| } | |||||
| return self; | |||||
| } | |||||
| static constexpr newfunc value = []() {if constexpr (provided) return T::tp_new; | |||||
| else if constexpr (varkw || noarg) return impl<>; | |||||
| else return nullptr;}(); | |||||
| }; | |||||
| struct tp_dealloc { | |||||
| static constexpr bool provided = HAS_MEMBER(T, tp_dealloc); | |||||
| template<typename = void> | |||||
| static void impl(PyObject* self) { | |||||
| reinterpret_cast<wrap_t*>(self)->inst()->~T(); | |||||
| Py_TYPE(self)->tp_free(self); | |||||
| } | |||||
| static constexpr destructor value = []() {if constexpr (provided) return T::tp_dealloc; | |||||
| else return impl<>;}(); | |||||
| }; | |||||
| struct tp_call { | |||||
| static constexpr bool valid = HAS_MEMBER(T, tp_call); | |||||
| static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}( | |||||
| [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {}); | |||||
| template<typename = void> | |||||
| static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | |||||
| auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||||
| CVT_RET_PYOBJ(inst->tp_call(args, kwargs)); | |||||
| } | |||||
| static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call; | |||||
| else if constexpr (valid) return impl<>; | |||||
| else return nullptr;}(); | |||||
| }; | |||||
| public: | |||||
| class TypeBuilder { | |||||
| std::vector<PyMethodDef> m_methods; | |||||
| PyTypeObject m_type; | |||||
| bool m_finalized = false; | |||||
| bool m_ready = false; | |||||
| void check_finalized() { | |||||
| if (m_finalized) { | |||||
| throw std::runtime_error("type is already finalized"); | |||||
| } | |||||
| } | |||||
| public: | |||||
| TypeBuilder(const TypeBuilder&) = delete; | |||||
| TypeBuilder& operator=(const TypeBuilder&) = delete; | |||||
| TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} { | |||||
| // static_assert(HAS_MEMBER(T, tp_name)); | |||||
| if constexpr (HAS_MEMBER(T, tp_name)) { | |||||
| m_type.tp_name = T::tp_name; | |||||
| } | |||||
| m_type.tp_dealloc = tp_dealloc::value; | |||||
| m_type.tp_call = tp_call::value; | |||||
| m_type.tp_basicsize = sizeof(wrap_t); | |||||
| m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
| m_type.tp_new = tp_new::value; | |||||
| } | |||||
| PyTypeObject* operator->() { | |||||
| return &m_type; | |||||
| } | |||||
| bool ready() const { | |||||
| return m_ready; | |||||
| } | |||||
| PyObject* finalize() { | |||||
| if (!m_finalized) { | |||||
| if (m_methods.size()) { | |||||
| m_methods.push_back({0}); | |||||
| if (m_type.tp_methods) { | |||||
| PyErr_SetString(PyExc_SystemError, "tp_method is already set"); | |||||
| return nullptr; | |||||
| } | |||||
| m_type.tp_methods = &m_methods[0]; | |||||
| } | |||||
| if (PyType_Ready(&m_type)) { | |||||
| return nullptr; | |||||
| } | |||||
| m_ready = true; | |||||
| } | |||||
| return (PyObject*)&m_type; | |||||
| } | |||||
| template<auto f> | |||||
| TypeBuilder& def(const char* name, const char* doc = nullptr) { | |||||
| check_finalized(); | |||||
| m_methods.push_back(make_meth_def<f>(name, doc)); | |||||
| return *this; | |||||
| } | |||||
| }; | |||||
| static TypeBuilder& type() { | |||||
| static TypeBuilder type_helper; | |||||
| return type_helper; | |||||
| } | |||||
| }; | |||||
| } // namespace pyext17 | |||||
| #undef HAS_MEMBER_TYPE | |||||
| #undef HAS_MEMBER | |||||
| #undef CVT_RET_PYOBJ | |||||
| @@ -0,0 +1,58 @@ | |||||
| from megengine.core.tensor.multipledispatch import Dispatcher | |||||
| def test_register_many(): | |||||
| f = Dispatcher("f") | |||||
| log = [] | |||||
| @f.register() | |||||
| def _(x: int): | |||||
| log.append("a") | |||||
| return log[-1] | |||||
| @f.register() | |||||
| def _(x: int): | |||||
| log.append("b") | |||||
| return log[-1] | |||||
| assert f(0) == "b" | |||||
| assert log == ["b"] | |||||
| def test_return_not_implemented(): | |||||
| f = Dispatcher("f") | |||||
| log = [] | |||||
| @f.register() | |||||
| def _(x: int): | |||||
| log.append("a") | |||||
| return log[-1] | |||||
| @f.register() | |||||
| def _(x: int): | |||||
| log.append("b") | |||||
| return NotImplemented | |||||
| assert f(0) == "a" | |||||
| assert log == ["b", "a"] | |||||
| def test_super(): | |||||
| f = Dispatcher("f") | |||||
| log = [] | |||||
| @f.register() | |||||
| def _(x: int): | |||||
| log.append("a") | |||||
| return log[-1] | |||||
| @f.register() | |||||
| def _(x: int): | |||||
| log.append("b") | |||||
| return f.super(x) | |||||
| assert f(0) == "a" | |||||
| assert log == ["b", "a"] | |||||