GitOrigin-RevId: 16da6d1491
@@ -7,9 +7,7 @@ from typing import Union
 import numpy as np
 from .. import _config
-from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import (
-    SymbolVar,
     Tensor,
     apply,
     astype_cpp,
@@ -17,9 +15,11 @@ from .._imperative_rt.core2 import (
     broadcast_cpp,
     getitem_cpp,
     matmul_cpp,
+    reshape_cpp,
+    setitem_cpp,
+    squeeze_cpp,
+    transpose_cpp,
 )
-from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
-from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp
 from ..ops import builtin
 from . import amp
 from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph
@@ -189,9 +189,7 @@ def _todo(*_):
 def _expand_args(args):
     if len(args) == 1:
-        if isinstance(
-            args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray),
-        ):
+        if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),):
             args = args[0]
     return args
@@ -8,7 +8,6 @@ import numpy as np
 from .._imperative_rt import make_const
 from .._imperative_rt.core2 import (
     Const,
-    SymbolVar,
     Tensor,
     _get_convert_inputs,
     _set_convert_inputs,
@@ -77,7 +76,7 @@ def result_type(*args):
 def isscalar(x):
-    if isinstance(x, (Tensor, SymbolVar)):
+    if isinstance(x, Tensor):
         return x._isscalar()
     return np.isscalar(x)
@@ -283,7 +282,7 @@ def interpret_subgraph(func, dtype, device):
         return results

     def apply_const(value, dtype=dtype, device=device):
-        return Const(value, dtype, device, None)
+        return Const(value, dtype, device)

     outputs, outputs_has_grad = func(args, apply_expr, apply_const)
     outputs = [
@@ -2,7 +2,7 @@
 # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
 import numpy as np
-from ..core._imperative_rt.core2 import SymbolVar, apply
+from ..core._imperative_rt.core2 import apply
 from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
 from ..core.tensor.array_method import _elwise
@@ -538,7 +538,7 @@ def topk(
     op = builtin.TopK(mode=mode)
     if not isinstance(k, Tensor):
-        k = Const(k, "int32", inp.device, None)
+        k = Const(k, "int32", inp.device)
     if len(inp.shape) == 1:
         if kth_only:
@@ -1222,7 +1222,7 @@ def batch_norm(
         raise ValueError("Invalid param_dim {}".format(param_dim))
         if x is None:
-            x = Const(value, inp.dtype, inp.device, None)
+            x = Const(value, inp.dtype, inp.device)
             shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
             (result,) = apply(builtin.Broadcast(), x, shape)
             return result
@@ -1446,7 +1446,7 @@ def sync_batch_norm(
     def _make_full_if_none(x, value):
         if x is None:
-            x = Const(value, inp.dtype, _device, None)
+            x = Const(value, inp.dtype, _device)
             (result,) = apply(builtin.Broadcast(), x, reduce_shape)
             return result
         elif x.ndim == 1:
@@ -7,7 +7,6 @@ import numpy as np
 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import (
     Const,
-    SymbolVar,
     apply,
     broadcast_cpp,
     dtype_promotion,
@@ -151,7 +150,7 @@ def full(
         shape = (shape,)
     if device is None:
         device = get_default_device()
-    x = Const(value, dtype, device, None)
+    x = Const(value, dtype, device)
     if type(shape) in (list, tuple) and len(shape) == 0:
         return x
     return broadcast_to(x, shape)
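A minimal, hedged sketch of the simplified constant path above: full() now builds its fill value via Const(value, dtype, device) with no trailing reference argument (the exact values below are illustrative, and Const is an internal helper):

    import megengine.functional as F
    from megengine.core._imperative_rt.core2 import Const

    # Public path: full() internally calls Const(value, dtype, device) and broadcasts it.
    t = F.full((2, 3), 2, dtype="int32")

    # Internal path (illustration only): the former fourth "reference" argument is gone.
    x = Const(2, "int32", t.device)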
@@ -216,7 +215,7 @@ def zeros(
     return full(shape, 0.0, dtype=dtype, device=device)

-def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
+def zeros_like(inp: Tensor) -> Tensor:
     r"""Returns a tensor filled with zeros with the same shape and data type as input tensor.

     Args:
@@ -235,7 +234,7 @@ def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
     return full_like(inp, 0.0)

-def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
+def ones_like(inp: Tensor) -> Tensor:
     r"""Returns a tensor filled with ones with the same shape and data type as input tensor.

     Args:
@@ -253,9 +252,7 @@ def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
     return full_like(inp, 1.0)

-def full_like(
-    inp: Union[Tensor, SymbolVar], value: Union[int, float]
-) -> Union[Tensor, SymbolVar]:
+def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
     r"""Returns a tensor filled with given value with the same shape as input tensor.

     Args:
@@ -272,7 +269,7 @@ def full_like(
         Tensor([[2 2 2]
          [2 2 2]], dtype=int32, device=xpux:0)
     """
-    x = Const(value, inp.dtype, inp.device, inp)
+    x = Const(value, inp.dtype, inp.device)
     if inp.ndim == 0:
         return x
     return broadcast_to(x, inp.shape)
@@ -668,9 +665,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
         >>> print(v.numpy(), index.numpy())
         [1. 4.] [0 3]
     """
-    if not isinstance(x, (Tensor, SymbolVar)):
+    if not isinstance(x, Tensor):
         raise TypeError("input must be a tensor")
-    if not isinstance(mask, (Tensor, SymbolVar)):
+    if not isinstance(mask, Tensor):
         raise TypeError("mask must be a tensor")
     if mask.dtype != np.bool_:
         raise ValueError("mask must be bool")
@@ -843,15 +840,11 @@ def linspace(
     if not (cur_device is None or device == cur_device):
         raise ("ambiguous device for linspace opr")
-    is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num])
-    if any(is_symbolvar) and not all(is_symbolvar):
-        raise TypeError("start, stop and num should all be VarNode or none of them")
-    if not isinstance(start, (Tensor, SymbolVar)):
+    if not isinstance(start, Tensor):
         start = Tensor(start, device=device)
-    if not isinstance(stop, (Tensor, SymbolVar)):
+    if not isinstance(stop, Tensor):
         stop = Tensor(stop, device=device)
-    if not isinstance(num, (Tensor, SymbolVar)):
+    if not isinstance(num, Tensor):
         num = Tensor(num, device=device)

     op = builtin.Linspace(comp_node=device)
@@ -901,9 +894,12 @@ def arange(
     if stop is None:
         start, stop = 0, start
-    start = Tensor(start, dtype="float32")
-    stop = Tensor(stop, dtype="float32")
-    step = Tensor(step, dtype="float32")
+    if not isinstance(start, Tensor):
+        start = Tensor(start, dtype="float32")
+    if not isinstance(stop, Tensor):
+        stop = Tensor(stop, dtype="float32")
+    if not isinstance(step, Tensor):
+        step = Tensor(step, dtype="float32")

     num = ceil((stop - start) / step)
     stop = start + step * (num - 1)
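A short, hedged illustration of the guarded conversion above: arguments that are already Tensors are now passed through instead of being rebuilt as float32 (the values are made up):

    import megengine.functional as F
    from megengine import Tensor

    # Plain Python scalars are still wrapped into float32 tensors internally.
    a = F.arange(0, 10, 2)

    # A Tensor argument now keeps its own dtype/device instead of being re-wrapped.
    start = Tensor(0.0)
    b = F.arange(start, 10, 2)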
@@ -7,11 +7,11 @@ small_tensor_cache = {}
 def _get_scalar_tensor_with_value(value, dtype=None, device=None):
     global small_tensor_cache
     if is_tracing():
-        ret = Const(value, dtype, device, None)
+        ret = Const(value, dtype, device)
     else:
         cache_key = (value, dtype, device)
         if cache_key not in small_tensor_cache:
-            ret = Const(value, dtype, device, None)
+            ret = Const(value, dtype, device)
             small_tensor_cache[cache_key] = ret
         else:
             ret = small_tensor_cache[cache_key]
@@ -154,6 +154,8 @@ class Tensor(_Tensor, ArrayMethodMixin):
     @name.setter
     def name(self, name):
         self._custom_name = name
+        if name == None:
+            name = ""
         self._name = self._prefix + "." + name if self._prefix else name
         self._set_name(self._name)
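A small, hedged example of the setter change above; a None name is now coerced to an empty string before the optionally prefixed name is built (the tensor values are made up):

    from megengine import Tensor

    t = Tensor([1, 2, 3])
    t.name = "feature"   # stored as before, optionally joined with an internal prefix
    t.name = None        # now normalized to "" before the prefixed name is assembled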
@@ -756,7 +756,7 @@ class Constant(Expr):
     def interpret(self, *inputs):
         if isinstance(self.value, RawTensor):
-            return (Const(self.value.numpy(), None, None, None),)
+            return (Const(self.value.numpy(), None, None),)
         return (self.value,)

     def __repr__(self):
@@ -395,7 +395,7 @@ class Network:
             for ind, var in enumerate(opr.outputs):
                 var.owner = repl_dict[opr]
                 var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
-                var.var = repl_dict[opr].outputs[ind].var
+                var._reset_var(repl_dict[opr].outputs[ind].var)
             repl_dict[opr].outputs = opr.outputs
         self._compile()
@@ -6,11 +6,11 @@ from typing import Sequence
 import numpy as np

 from ..core import _imperative_rt as rt
-from ..core._imperative_rt.core2 import SymbolVar, apply
+from ..core._imperative_rt.core2 import apply, set_py_varnode_type
 from ..core._trace_option import use_symbolic_shape
 from ..core._wrap import Device
 from ..core.ops import builtin
-from ..core.tensor.array_method import ArrayMethodMixin
+from ..tensor import Tensor
 from .comp_graph_tools import replace_vars
 from .module_stats import (
     preprocess_receptive_field,
@@ -23,26 +23,72 @@ class NetworkNode:
     pass

-class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
-    pass
+class VarNode(NetworkNode, Tensor):
+    _users = None
+    _owner = None
+    _name = None
+    _id = None

+    def __new__(cls, var, *, owner_opr=None, name=None):
+        obj = Tensor.__new__(cls, var)
+        return obj

-class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
-    def __init__(self, var=None, *, owner_opr=None, name=None):
-        SymbolVar.__init__(self, var)
-        self.users = []  # List[OpNode]
-        self.owner = owner_opr
+    def __init__(self, var, *, owner_opr=None, name=None):
+        self._owner = owner_opr
         self.name = name
-        self.id = id(self)

     @classmethod
     def load(cls, sym_var, owner_opr):
-        obj = cls()
+        obj = cls(sym_var)
         obj.var = sym_var  # mgb varnode
         obj.name = sym_var.name
         obj.owner = owner_opr
         return obj

+    @property
+    def users(self):
+        if self._users is None:
+            self._users = []
+        return self._users
+
+    @property
+    def owner(self):
+        return self._owner
+
+    @owner.setter
+    def owner(self, owner):
+        self._owner = owner
+
+    @property
+    def id(self):
+        if self._id is None:
+            self._id = id(self)
+        return self._id
+
+    @property
+    def var(self):
+        return super().var()
+
+    @var.setter
+    def var(self, var):
+        self._reset(var)
+
+    def _reset(self, other):
+        if not isinstance(other, Tensor):
+            other = VarNode(other)
+        super()._reset(other)
+        self.owner = None
+
+    def _reset_var(self, var):
+        origin_owner = self.owner
+        self.var = var
+        self.var.name = self.name
+        self.owner = origin_owner
+
+    @property
+    def graph(self):
+        return super().graph()
+
     def _get_var_shape(self, axis=None):
         opdef = (
             builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis)
@@ -77,14 +123,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
             return rst
         return self._get_var_shape() if self.var else None

-    @property
-    def dtype(self):
-        return self.var.dtype if self.var else None
-
-    @property
-    def ndim(self):
-        return super().ndim
-
     def __bool__(self):
         return False
@@ -92,27 +130,11 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
     __int__ = None
     __float__ = None
     __complex__ = None
-    __repr__ = lambda self: "VarNode:" + self.name

     def __hash__(self):
         return id(self)

-    def numpy(self):
-        return super().numpy()
-
-    def _reset(self, other):
-        if not isinstance(other, VarNode):
-            assert self.graph, "VarNode _reset must have graph"
-            node = ImmutableTensor(other, graph=self.graph)
-            node.compile(self.graph)
-            other = node.outputs[0]
-        if self.owner is not None:
-            idx = self.owner.outputs.index(self)
-            self.owner.outputs[idx] = VarNode(
-                self.var, owner_opr=self.owner, name=self.var.name
-            )
-        self.var = other.var
-        self.owner = None
-
     def set_owner_opr(self, owner_opr):
         self.owner = owner_opr
@@ -158,8 +180,7 @@ class OpNode(NetworkNode):
         assert len(outputs) == len(self.outputs)
         self._opr = outputs[0].owner
         for i in range(len(self.outputs)):
-            self.outputs[i].var = outputs[i]
-            self.outputs[i].var.name = self.outputs[i].name
+            self.outputs[i]._reset_var(outputs[i])
             assert self.outputs[i].owner is self

     def add_inp_var(self, x):
@@ -214,8 +235,9 @@ class Host2DeviceCopy(OpNode):
         outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
         self._opr = outputs.owner
         if len(self.outputs) == 0:
-            self.outputs.append(VarNode(owner_opr=self, name=self.name))
-        self.outputs[0].var = outputs
+            self.outputs.append(VarNode(outputs, owner_opr=self, name=self.name))
+        else:
+            self.outputs[0]._reset_var(outputs)
         assert self.outputs[0].owner is self
@@ -262,8 +284,9 @@ class ConstOpBase(OpNode):
             data = data.astype(np.int32)
         varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
         if len(self.outputs) == 0:
-            self.outputs.append(VarNode(owner_opr=self, name=self.name))
-        self.outputs[0].var = varnode
+            self.outputs.append(VarNode(varnode, owner_opr=self, name=self.name))
+        else:
+            self.outputs[0]._reset_var(varnode)
         self._opr = varnode.owner

     @classmethod
@@ -313,7 +336,7 @@ class ReadOnlyOpNode(OpNode):
         if bool(repl_dict):
             out_vars = replace_vars(self._opr.outputs, repl_dict)
             for ind, o in enumerate(self.outputs):
-                o.var = out_vars[ind]
+                o._reset_var(out_vars[ind])

 class Elemwise(OpNode):
@@ -785,3 +808,6 @@ class AssertEqual(OpNode):
 class CvtColorForward(OpNode):
     type = "CvtColor"
     opdef = builtin.CvtColor
+
+
+set_py_varnode_type(VarNode)
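With set_py_varnode_type(VarNode) registered above and VarNode now deriving from Tensor, graph variables loaded through megengine.utils.network can flow through ordinary tensor-style operations. A hedged sketch (the model path is hypothetical, and the exact attribute names beyond those in this diff are assumptions):

    import megengine.functional as F
    from megengine.utils.network import Network

    net = Network.load("model.mge")   # hypothetical serialized model path
    var = net.output_vars[0]          # a VarNode, now also a Tensor subclass
    doubled = var * 2                 # apply() dispatches back to VarNode outputs
    print(var.shape, var.dtype)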
@@ -114,6 +114,8 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
     }
 }

+py::object Py_Varnode = py::none();
+
 void init_graph_rt(py::module m) {
     static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
             std::make_unique<mgb::OprFootprint>()};
@@ -124,40 +126,44 @@ void init_graph_rt(py::module m) {
     def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");

-    py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
-            .def_property_readonly(
-                    "owner", [](cg::VarNode* v) { return v->owner_opr(); })
-            .def_property_readonly(
-                    "graph", [](cg::VarNode* v) { return v->owner_graph(); })
-            .def_property(
-                    "name", py::overload_cast<>(&VarNode::name, py::const_),
-                    py::overload_cast<std::string>(&VarNode::name))
-            .def_property_readonly("dtype", [](cg::VarNode* v) { return v->dtype(); })
-            .def_property_readonly(
-                    "comp_node", [](cg::VarNode* v) { return v->comp_node(); })
-            .def_property_readonly(
-                    "shape",
-                    [](cg::VarNode* v) -> const TensorShape* {
-                        auto&& mgr = v->owner_graph()->static_infer_manager();
-                        return mgr.infer_shape_fallible(v);
-                    })
-            .def_property_readonly(
-                    "value",
-                    [](cg::VarNode* v) -> py::object {
-                        auto&& mgr = v->owner_graph()->static_infer_manager();
-                        auto&& type = mgr.get_infer_type(v);
-                        using InferType = cg::static_infer::InferType;
-                        if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
-                            return py::none();
-                        }
-                        auto* val = mgr.infer_value_fallible(v);
-                        if (!val) {
-                            return py::none();
-                        }
-                        return py::cast(*val).attr("numpy")();
-                    })
-            .def_property_readonly("id", [](cg::VarNode* v) { return (v->id()); })
-            .def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); });
+    Py_Varnode =
+            py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
+                    .def_property_readonly(
+                            "owner", [](cg::VarNode* v) { return v->owner_opr(); })
+                    .def_property_readonly(
+                            "graph", [](cg::VarNode* v) { return v->owner_graph(); })
+                    .def_property(
+                            "name", py::overload_cast<>(&VarNode::name, py::const_),
+                            py::overload_cast<std::string>(&VarNode::name))
+                    .def_property_readonly(
+                            "dtype", [](cg::VarNode* v) { return v->dtype(); })
+                    .def_property_readonly(
+                            "comp_node", [](cg::VarNode* v) { return v->comp_node(); })
+                    .def_property_readonly(
+                            "shape",
+                            [](cg::VarNode* v) -> const TensorShape* {
+                                auto&& mgr = v->owner_graph()->static_infer_manager();
+                                return mgr.infer_shape_fallible(v);
+                            })
+                    .def_property_readonly(
+                            "value",
+                            [](cg::VarNode* v) -> py::object {
+                                auto&& mgr = v->owner_graph()->static_infer_manager();
+                                auto&& type = mgr.get_infer_type(v);
+                                using InferType = cg::static_infer::InferType;
+                                if (!(type.value &
+                                      (InferType::CONST | InferType::RT_STATIC))) {
+                                    return py::none();
+                                }
+                                auto* val = mgr.infer_value_fallible(v);
+                                if (!val) {
+                                    return py::none();
+                                }
+                                return py::cast(*val).attr("numpy")();
+                            })
+                    .def_property_readonly(
+                            "id", [](cg::VarNode* v) { return (v->id()); })
+                    .def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); });

     py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(
             m, "OperatorNode")
@@ -8,6 +8,9 @@
 #include "megbrain/graph.h"
 #include "megbrain/plugin/opr_footprint.h"

+namespace py = pybind11;
+extern py::object Py_Varnode;
+
 template <typename T>
 class GraphNodePtr {
     std::shared_ptr<mgb::cg::ComputingGraph> m_graph;
@@ -48,58 +48,11 @@ namespace mgb::imperative::python {
 namespace {
 WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;

-struct SymbolVarContext {
-    TransformationContext context;
-    std::shared_ptr<SymbolTransformation> symbol_tsf;
-    std::shared_ptr<ScalarTransformation> scalar_tsf;
-    std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
-    std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;
-
-    SymbolVarContext(cg::ComputingGraph* graph) {
-        symbol_tsf = std::make_shared<SymbolTransformation>(graph);
-        scalar_tsf = std::make_shared<ScalarTransformation>();
-        dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
-        dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
-        Transformation::swap_context(context);
-    }
-
-    void init() {
-        symbol_tsf->register_at(Transformation::top());
-        scalar_tsf->register_at(Transformation::top());
-        dtype_promote_tsf->register_at(Transformation::top());
-        dim_expansion_tsf->register_at(Transformation::top());
-    }
-
-    ValueRef symvar2val(py::handle py_symbol_var) {
-        auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
-        ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node);
-        if (symbol_var->is_scalar) {
-            value = scalar_tsf->value_type().make(value);
-        }
-        return value;
-    }
-
-    py::object val2symvar(py::handle typeobj, ValueRef value) {
-        bool is_scalar = false;
-        if (auto* scalar_value = value.as(scalar_tsf->value_type())) {
-            value = scalar_value->value();
-            is_scalar = true;
-        }
-        auto* node = value.cast(symbol_tsf->value_type()).node();
-        auto py_symbol_var =
-                typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
-        py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
-        return py_symbol_var;
-    }
-
-    ~SymbolVarContext() { Transformation::swap_context(context); }
-};
-
 }  // namespace

 interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
 PyTypeObject* py_tensor_type = nullptr;
+PyTypeObject* py_varnode_type = nullptr;
 pybind11::handle py_device_type = nullptr;
 PyObject* cpp_use_symbolic_shape;
@@ -136,22 +89,6 @@ PyObject* py_apply(
     auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
     SmallVector<ValueRef, 8> tensors(nargs);

-    SmallVector<bool, 8> is_symbol_var(nargs, false);
-    ComputingGraph* cg = nullptr;
-    for (size_t i = 0; i < nargs; ++i) {
-        if ((!TensorWrapper::try_cast(args[i])) &&
-            py::isinstance<PySymbolVar>(py::handle(args[i]))) {
-            is_symbol_var[i] = true;
-            ComputingGraph* cur_cg =
-                    py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph();
-            if (cg == nullptr) {
-                cg = cur_cg;
-            } else {
-                mgb_assert(cg == cur_cg);
-            }
-        }
-    }
-
     mgb::CompNode target_cn;
     mgb::DType target_dtype;
@@ -174,35 +111,11 @@ PyObject* py_apply(
         }
     };

-    if (cg != nullptr) {
-        // swap to a special context to reuse scalar handle
-        size_t symbol_var_idx = 8;
-        SymbolVarContext context(cg);
-        context.init();
-        for (size_t i = 0; i < nargs; ++i) {
-            if (is_symbol_var[i]) {
-                symbol_var_idx = i;
-                tensors[i] = context.symvar2val(args[i]);
-            } else if (
-                    DTypePromoteCfg::convert_input_enabled &&
-                    op->same_type<Elemwise>()) {
-                tensors[i] = convert_pyinput_to_tensor(i);
-            } else {
-                PyErr_SetString(
-                        PyExc_TypeError, "py_apply expects tensor as inputs");
-                return nullptr;
-            }
-        }
-        auto outputs = imperative::apply(*op, tensors);
-        auto ret = pybind11::tuple(outputs.size());
-        auto typeobj = py::handle(args[symbol_var_idx]).get_type();
-        for (size_t i = 0; i < outputs.size(); ++i) {
-            ret[i] = context.val2symvar(typeobj, outputs[i]);
-        }
-        return ret.release().ptr();
-    }
-
+    bool is_varnode_apply = false;
     for (size_t i = 0; i < nargs; ++i) {
+        if (PyObject_TypeCheck(args[i], py_varnode_type)) {
+            is_varnode_apply = true;
+        }
         if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
             tensors[i] = tw->m_tensor->data();
         } else if (
@@ -218,8 +131,9 @@ PyObject* py_apply(
     auto outputs = [&] { return imperative::apply(*op, tensors); }();
     size_t nout = outputs.size();
     auto ret = py::tuple(nout);
+    PyTypeObject* py_type = is_varnode_apply ? py_varnode_type : py_tensor_type;
     for (size_t i = 0; i < nout; ++i) {
-        ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i]));
+        ret[i] = TensorWrapper::make(py_type, std::move(outputs[i]));
     }
     return ret.release().ptr();
 }
@@ -622,9 +536,17 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
             CreateTensor::Kind kind = is_const ? CreateTensor::Const
                                     : no_cache ? CreateTensor::Unique
                                                : CreateTensor::Common;
-            auto&& hval = pyobj2hval(data, cn, dtype);
-            auto val = imperative::apply(
-                    CreateTensor(kind, cn, hval.dtype, hval.shape), hval.storage)[0];
+            ValueRef val;
+            if (py::isinstance(data, Py_Varnode)) {
+                cg::VarNode* m_node = py::handle(data).cast<cg::VarNode*>();
+                val = imperative::apply(
+                        CreateNode(m_node), Span<ValueRef>(nullptr, nullptr))[0];
+            } else {
+                auto&& hval = pyobj2hval(data, cn, dtype);
+                val = imperative::apply(
+                        CreateTensor(kind, cn, hval.dtype, hval.shape),
+                        hval.storage)[0];
+            }
             m_tensor.emplace(val);
         }
@@ -734,6 +656,20 @@ PyObject* TensorWrapper::isscalar() {
     }
 }

+PyObject* TensorWrapper::_var() {
+    TypedValueRef<NodeValue> value =
+            imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
+    auto* node = value->node();
+    return py::cast(node).release().ptr();
+}
+
+PyObject* TensorWrapper::_graph() {
+    TypedValueRef<NodeValue> value =
+            imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
+    auto* graph = value->graph();
+    return py::cast(graph).release().ptr();
+}
+
 struct TensorWeakRef {
     ValueWeakRef data;
@@ -807,6 +743,10 @@ void init_tensor(py::module m) {
                               .register_at<Segment::Scalar>(
                                       std::make_shared<ScalarTransformation>())
                               .release());
+    MGB_MARK_USED_VAR(transformations
+                              .register_at<Segment::Symbol>(
+                                      std::make_shared<SymbolTransformation>())
+                              .release());
     MGB_MARK_USED_VAR(transformations
                               .register_at<Segment::DTypePromote>(
                                       std::make_shared<DTypePromoteTransformation>())
@@ -863,6 +803,8 @@ void init_tensor(py::module m) {
                     .def<&TensorWrapper::_detail>("_detail")
                     .def<&TensorWrapper::_set_name>("_set_name")
                     .def<&TensorWrapper::_watch>("_watch")
+                    .def<&TensorWrapper::_var>("var")
+                    .def<&TensorWrapper::_graph>("graph")
                     .def_getset<
                             &TensorWrapper::module_trace_info,
                             &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
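Given the var/graph accessors bound above and the Python-side VarNode properties earlier in this diff, the underlying megbrain VarNode and its owning graph are reachable from Python. A brief hedged sketch (the model path is hypothetical):

    from megengine.utils.network import Network

    net = Network.load("model.mge")   # hypothetical path, as in the earlier sketch
    vn = net.output_vars[0]
    raw_var = vn.var                  # the wrapped mgb VarNode object
    graph = vn.graph                  # its owning ComputingGraph
    print(raw_var.name, graph)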
@@ -875,43 +817,6 @@ void init_tensor(py::module m) {
             .def(py::init<const TensorWrapper&>())
             .def("__call__", &TensorWeakRef::operator());

-    py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
-            .def_property_readonly(
-                    "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
-            .def_property(
-                    "var", [](PySymbolVar* v) { return v->m_node; },
-                    [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
-            .def_property_readonly(
-                    "device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
-            .def_property_readonly(
-                    "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
-            .def_property_readonly(
-                    "shape",
-                    [](PySymbolVar* v) -> const TensorShape* {
-                        auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
-                        return mgr.infer_shape_fallible(v->m_node);
-                    })
-            .def("numpy",
-                 [](PySymbolVar* v) {
-                     auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
-                     auto&& type = mgr.get_infer_type(v->m_node);
-                     using InferType = cg::static_infer::InferType;
-                     if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
-                         throw py::value_error("value invalid!");
-                     }
-                     auto* val = mgr.infer_value_fallible(v->m_node);
-                     if (!val) {
-                         throw py::value_error("value invalid!");
-                     }
-                     auto np_val = py::cast(*val).attr("numpy")();
-                     return np_val;
-                 })
-            .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
-            .def(py::init([](cg::VarNode* node) {
-                return std::make_shared<PySymbolVar>(node);
-            }),
-                 py::arg() = nullptr);
-
     static PyMethodDef method_defs[] = {
             MGE_PY_INTERFACE(apply, py_apply),
             MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
@@ -1027,6 +932,10 @@ void init_tensor(py::module m) {
         py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
     });

+    m.def("set_py_varnode_type", [](py::object type_obj) {
+        py_varnode_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
+    });
+
     m.def("set_py_device_type",
           [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });
@@ -1217,31 +1126,6 @@ void init_tensor(py::module m) {
         }
     });

-    m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
-        auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
-            auto make_scalar_shape = [&](CompNode device) {
-                return imperative::apply(
-                        CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
-                        HostStorage::make(device))[0];
-            };
-            return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
-        };
-        if (py::isinstance<PySymbolVar>(tensor)) {
-            auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
-            SymbolVarContext context(graph);
-            context.init();
-            auto output = reduce_to_scalar(
-                    *op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor));
-            auto typeobj = tensor.get_type();
-            return context.val2symvar(typeobj, output);
-        } else {
-            auto* tw = TensorWrapper::try_cast(tensor.ptr());
-            auto output = reduce_to_scalar(
-                    *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
-            return TensorWrapper::make(py_tensor_type, output);
-        }
-    });
-
     m.def("name_tensor", [](std::string name, py::object tensor) {
         auto* tw = TensorWrapper::try_cast(tensor.ptr());
         auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
@@ -10,6 +10,8 @@
 #include "./pyext17.h"
 #include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/transformations/scalar.h"
+#include "megbrain/imperative/transformations/symbol.h"
 #include "megbrain/imperative/utils/span.h"

 namespace mgb::imperative::python {
@@ -27,6 +29,7 @@ namespace mgb::imperative::python {
 extern interpreter::Interpreter::Channel* interpreter_for_py;
 extern PyTypeObject* py_tensor_type;
+extern PyTypeObject* py_varnode_type;
 extern pybind11::handle py_device_type;
 extern PyObject* cpp_use_symbolic_shape;
 extern PyObject* cpp_astensor1d;
@@ -126,16 +129,11 @@ public:
     void set_module_trace_info(PyObject*);
     void _set_name(PyObject*);
    PyObject* _detail();
+    PyObject* _var();
+    PyObject* _graph();
     void _watch();
 };

-struct PySymbolVar {
-    cg::VarNode* m_node = nullptr;
-    bool is_scalar = false;
-    PySymbolVar() = default;
-    PySymbolVar(VarNode* m) : m_node(m) {}
-};
-
 PyObject* py_apply(
         PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);
@@ -146,15 +146,6 @@ PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
             continue;
         }

-        if (py::isinstance<PySymbolVar>(py::handle(handle))) {
-            auto var = py::handle(handle).cast<PySymbolVar*>();
-            mgb::DType type = var->m_node->dtype();
-            auto&& descr = npy::dtype_mgb2np_descr(type);
-            Py_INCREF(descr.get());
-            tensors.emplace_back(descr.get());
-            continue;
-        }
-
         PyArray_Descr* descr = scalar2dtype(handle);
         if (descr) {
             scalars.emplace_back(descr);
@@ -204,17 +195,12 @@ CompNode _get_device(PyObject* const* args, size_t nargs) {
         PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
         TensorWrapper* tw = TensorWrapper::try_cast(handle);
-        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
-        if (tw || is_symvar) {
+        if (tw) {
             if (!valid) {
-                cn = tw ? tw->m_tensor->comp_node()
-                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
+                cn = tw->m_tensor->comp_node();
                 valid = true;
             } else {
-                CompNode cn1 = tw ? tw->m_tensor->comp_node()
-                                  : py::handle(handle)
-                                            .cast<PySymbolVar*>()
-                                            ->m_node->comp_node();
+                CompNode cn1 = tw->m_tensor->comp_node();
                 if (cn1 != cn) {
                     throw py::value_error(ssprintf(
                             "ambiguous device: %s (from %s) vs %s (from %s)",
@@ -258,10 +244,6 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
 }

 bool is_scalar(PyObject* tensor) {
-    if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
-        auto var = py::handle(tensor).cast<PySymbolVar*>();
-        return var->is_scalar;
-    }
     auto* tw = TensorWrapper::try_cast(tensor);
     if (tw) {
         return tw->m_tensor->is_scalar();
@@ -319,8 +301,7 @@ py::object device2obj(py::handle device, bool mapping = false) {
     }
 }

-py::object _Const(
-        py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
+py::object _Const(py::handle value, py::handle dtype, py::handle device) {
     py::object val = py::reinterpret_borrow<py::object>(value);
     if (PyArray_Check(value.ptr())) {
         py::tuple strides =
@@ -338,32 +319,6 @@ py::object _Const(
             val = val.attr("reshape")(orig_shp);
         }
     }
-    py::object ref;
-    if (py::isinstance<py::tuple>(ref_hdl)) {
-        py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
-        if (tup.size()) {
-            ref = tup[0];
-        } else {
-            ref = py::none();
-        }
-    } else {
-        ref = py::reinterpret_borrow<py::object>(ref_hdl);
-    }
-    if (py::isinstance<PySymbolVar>(ref)) {
-        auto ref_var = ref.cast<PySymbolVar*>();
-        auto* graph = ref_var->m_node->owner_graph();
-        CompNode cn;
-        if (device.ptr() == Py_None) {
-            cn = ref_var->m_node->comp_node();
-        } else {
-            cn = device2obj(device).cast<CompNode>();
-        }
-        OperatorNodeConfig config(cn);
-        auto hv = npy::np2tensor(
-                val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
-        auto typeobj = ref.get_type();
-        return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
-    }
     py::object device_obj = device2obj(device, true);
     py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
     return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
@@ -373,7 +328,7 @@ py::tuple _make_shape_tuple(py::handle shape) {
     py::list orig;
     py::list ret(0);
     auto solve_one = [&](py::handle val) {
-        if (TensorWrapper::try_cast(val.ptr()) || py::isinstance<PySymbolVar>(val)) {
+        if (TensorWrapper::try_cast(val.ptr())) {
             py::object np = getattr(val, "numpy")();
             PyArrayObject* arr = (PyArrayObject*)np.ptr();
             PyObject* maybe_list = PyArray_ToList(arr);
@@ -415,25 +370,53 @@ py::tuple _make_shape_tuple(py::handle shape) {
     return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr()));
 }

-bool is_tensor_or_symbolvar(py::handle arg) {
-    return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg);
+bool is_tensor(py::handle arg) {
+    return bool(TensorWrapper::try_cast(arg.ptr()));
 }

 bool is_py_sequence(py::handle arg) {
-    if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) ||
-        py::isinstance<PySymbolVar>(arg)) {
+    if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr())) {
         return false;
     }
     return PySequence_Check(arg.ptr());
 }

-mgb::DType _get_dtype(py::handle tensor) {
-    if (auto tw = TensorWrapper::try_cast(tensor.ptr())) {
-        return tw->m_tensor->dtype();
-    } else {
-        auto var = tensor.cast<PySymbolVar*>();
-        return var->m_node->dtype();
-    }
+py::object get_res_by_refhdl(
+        py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
+    py::object res = _Const(value, dtype, device);
+    py::object ref;
+    if (py::isinstance<py::tuple>(ref_hdl)) {
+        py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl);
+        if (tup.size()) {
+            ref = tup[0];
+        } else {
+            ref = py::none();
+        }
+    } else {
+        ref = py::reinterpret_borrow<py::object>(ref_hdl);
+    }
+    if (PyObject_TypeCheck(ref.ptr(), py_varnode_type)) {
+        auto temp = dtype.cast<mgb::DType>();
+        ComputingGraph* graph = getattr(ref, "graph").cast<ComputingGraph*>();
+        cg::VarNode* node = getattr(ref, "var").cast<cg::VarNode*>();
+        CompNode cn;
+        if (device.ptr() == Py_None) {
+            cn = node->comp_node();
+        } else {
+            cn = device2obj(device).cast<CompNode>();
+        }
+        OperatorNodeConfig config(cn);
+        auto hv = npy::np2tensor(
+                value.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
+        auto typeobj = ref.get_type();
+        return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
+    }
+    return res;
+}
+
+mgb::DType _get_dtype(py::handle tensor) {
+    auto tw = TensorWrapper::try_cast(tensor.ptr());
+    return tw->m_tensor->dtype();
 }

 py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
@@ -457,12 +440,12 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {

 py::object _convert_single_value_cpp(
         py::handle value, py::handle dtype, py::handle device) {
-    if (is_tensor_or_symbolvar(value)) {
+    if (is_tensor(value)) {
         if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
             return _astype_cpp(value, dtype);
         }
     } else {
-        return _Const(value, dtype, device, py::none());
+        return _Const(value, dtype, device);
     }
     return py::reinterpret_borrow<py::object>(value);
 }
@@ -475,28 +458,8 @@ py::object _convert_inputs_cpp(
     for (size_t i = 0; i < nargs; ++i) {
         py::handle h = py::handle(args[i]);
         lis.append(h);
-        if (py::isinstance<PySymbolVar>(h)) {
-            auto var = h.cast<PySymbolVar*>();
-            auto g = var->m_node->owner_graph();
-            if (!graph) {
-                graph = g;
-                typeobj = h.get_type();
-            } else {
-                mgb_assert(graph == g);
-            }
-        }
-    }
-    if (graph) {
-        CompNode cn = device2obj(device).cast<CompNode>();
-        for (size_t i = 0; i < nargs; ++i) {
-            OperatorNodeConfig config(cn);
-            auto hv = npy::np2tensor(
-                    lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
-            if (!py::isinstance<PySymbolVar>(lis[i])) {
-                lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
-            }
-        }
     }
     auto convert = [&](py::object value) {
         if (value.is_none()) {
             return value;
@@ -517,7 +480,8 @@ py::object _astensor1d_cpp(
     if (device.ptr() != Py_None) {
         device_obj = device2obj(device);
     }
-    if (py::isinstance<PySymbolVar>(value)) {
+    if (PyObject_TypeCheck(value.ptr(), py_varnode_type)) {
         try {
             getattr(value, "ndim");
         } catch (py::error_already_set& err) {
@@ -537,14 +501,15 @@ py::object _astensor1d_cpp(
             return ret;
         }
     }
     size_t ndim = 999;
     if (hasattr(value, "ndim")) {
         ndim = getattr(value, "ndim").cast<size_t>();
         if (ndim != 0 && ndim != 1) {
             throw py::value_error("ndim != 1 or 0, get : " + std::to_string(ndim));
         }
-        if (!is_tensor_or_symbolvar(value)) {
-            return _Const(value, dtype, device, ref);
+        if (!is_tensor(value)) {
+            return get_res_by_refhdl(value, dtype, device, ref);
         } else {
             return py::reinterpret_borrow<py::object>(value);
         }
@@ -555,13 +520,13 @@ py::object _astensor1d_cpp(
     py::list lis = py::reinterpret_steal<py::list>(PySequence_List(value.ptr()));
     bool need_concat = false;
     for (size_t i = 0; i < lis.size(); ++i) {
-        if (is_tensor_or_symbolvar(lis[i])) {
+        if (is_tensor(lis[i])) {
             need_concat = true;
             break;
         }
     }
     if (!need_concat) {
-        return _Const(value, dtype, device, ref);
+        return get_res_by_refhdl(value, dtype, device, ref);
     }
     if (lis.size() > 1) {
         std::vector<PyObject*> c_args(lis.size() + 1);
@@ -600,10 +565,9 @@ py::object _astensor1d_cpp(
 }

 py::object _get_index(py::object tensor, py::object src) {
-    if (!TensorWrapper::try_cast(tensor.ptr()) &&
-        !py::isinstance<PySymbolVar>(tensor)) {
+    if (!TensorWrapper::try_cast(tensor.ptr())) {
         auto get_const = [&](mgb::DType dtype) -> py::object {
-            return _Const(tensor, py::cast(dtype), src.attr("device"), src);
+            return _Const(tensor, py::cast(dtype), src.attr("device"));
         };
         if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) {
             tensor = get_const(dtype::Bool());
@@ -636,9 +600,8 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) {
     }
     py::object iobj;
     if (PyArray_Check(index.ptr())) {
-        iobj =
-                _Const(index, py::cast((mgb::DType)dtype::Bool()),
-                       getattr(tensor, "device"), tensor);
+        iobj = _Const(
+                index, py::cast((mgb::DType)dtype::Bool()), getattr(tensor, "device"));
     } else {
         iobj = py::reinterpret_borrow<py::object>(index);
     }
@@ -920,8 +883,8 @@ py::object _expand_args(py::handle args) {
         return py::reinterpret_borrow<py::object>(args);
     }
     py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
-    if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
-                                 is_tensor_or_symbolvar(args_tup[0].ptr()))) {
+    if (args_tup.size() == 1 &&
+        (PySequence_Check(args_tup[0].ptr()) || is_tensor(args_tup[0].ptr()))) {
         return py::reinterpret_borrow<py::object>(args_tup[0]);
     } else {
         return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
@@ -948,7 +911,8 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
 bool enable_fastpath(py::handle inp) {
     auto&& tm_tr = TransformationManager::get_instance()
                            .segments[TransformationManager::Segment::ModuleTrace];
-    if (!TensorWrapper::try_cast(inp.ptr()) ||
+    bool is_varnode = PyObject_TypeCheck(inp.ptr(), py_varnode_type);
+    if (is_varnode ||
         TransformationManager::get_instance()
                         .segments[TransformationManager::Segment::Trace]
                         .size() > 0 ||
@@ -1181,10 +1145,8 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
 py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
     py::object org_shape = getattr(inp_hdl, "shape");
     py::object val = py::reinterpret_borrow<py::object>(val_hdl);
-    if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) {
-        val =
-                _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"),
-                       inp_hdl);
+    if (!TensorWrapper::try_cast(val.ptr())) {
+        val = _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"));
     }

     py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
@@ -1308,12 +1270,12 @@ py::object _split_cpp(
                     repr(nsplits_or_sections_hdl).cast<std::string>());
         }
         py::object pos = div_points[i] - div_points[i - 1];
-        if (is_tensor_or_symbolvar(pos)) {
+        if (is_tensor(pos)) {
             partitions.append(pos);
         } else {
             partitions.append(
                     _Const(pos, py::cast((mgb::DType)dtype::Int32()),
-                           getattr(inp_hdl, "device"), inp_hdl));
+                           getattr(inp_hdl, "device")));
         }
     }
     op = Split::make(axis, 0);
| @@ -1438,7 +1400,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) { | |||||
| py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | ||||
| py::object obj = _expand_args(args); | py::object obj = _expand_args(args); | ||||
| py::list lis; | py::list lis; | ||||
| if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) { | |||||
| if (!is_tensor(obj.ptr()) && PySequence_Check(obj.ptr())) { | |||||
| lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr())); | lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr())); | ||||
| } else { | } else { | ||||
| py::object np = getattr(obj, "numpy")(); | py::object np = getattr(obj, "numpy")(); | ||||
| @@ -1631,7 +1593,7 @@ PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs) | |||||
| PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) { | ||||
| try { | try { | ||||
| return _Const(args[0], args[1], args[2], args[3]).release().ptr(); | |||||
| return _Const(args[0], args[1], args[2]).release().ptr(); | |||||
| } | } | ||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
| } | } | ||||
| @@ -1696,4 +1658,4 @@ PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
| } | } | ||||
| } // namespace mgb::imperative::python | |||||
| } // namespace mgb::imperative::python | |||||
| @@ -20,11 +20,12 @@ public: | |||||
| DimExpansion, | DimExpansion, | ||||
| Grad, | Grad, | ||||
| Scalar, | Scalar, | ||||
| Symbol, | |||||
| Trace, | Trace, | ||||
| Eval, | Eval, | ||||
| }; | }; | ||||
| std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments; | |||||
| std::array<std::vector<std::shared_ptr<Transformation>>, 8> segments; | |||||
| private: | private: | ||||
| template <Segment segment> | template <Segment segment> | ||||
| @@ -11,7 +11,7 @@ from megengine.utils.network_node import VarNode | |||||
| def _default_compare_fn(x, y): | def _default_compare_fn(x, y): | ||||
| if isinstance(x, tensor): | |||||
| if isinstance(x, tensor) and not isinstance(x, VarNode): | |||||
| x = x.numpy() | x = x.numpy() | ||||
| elif not isinstance(x, np.ndarray): | elif not isinstance(x, np.ndarray): | ||||
| x = get_var_value(x) | x = get_var_value(x) | ||||
| @@ -679,6 +679,18 @@ def test_utils_astensor1d(is_varnode): | |||||
| assert isinstance(xx, type(reference)) | assert isinstance(xx, type(reference)) | ||||
| np.testing.assert_equal(xx.numpy(), [1, 2, 3]) | np.testing.assert_equal(xx.numpy(), [1, 2, 3]) | ||||
| # varnode | |||||
| if is_varnode: | |||||
| a = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32") | |||||
| b = np.array([[True, False, True], [False, True, True]]) | |||||
| aa = make_tensor(a, network) | |||||
| bb = make_tensor(b, network) | |||||
| x, y = F.cond_take(bb, aa) | |||||
| for dtype in [None, "float32"]: | |||||
| xx = astensor1d(x, reference, dtype=dtype) | |||||
| assert isinstance(xx, type(reference)) | |||||
| np.testing.assert_equal(get_var_value(xx), get_var_value(x)) | |||||
| def test_device(): | def test_device(): | ||||
| x = tensor([1, 2, 3], dtype="float32") | x = tensor([1, 2, 3], dtype="float32") | ||||
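The new varnode branch in the hunk above feeds the output of F.cond_take through astensor1d and checks that the values survive the round trip. For orientation, the data x holds is just a masked selection of a by b; the snippet below re-traces that in plain numpy (numpy only, not the MegEngine API), assuming cond_take follows the usual convention of returning the selected values plus their flat indices.

    import numpy as np

    a = np.array([[1, 2, 3], [4, 5, 6]], dtype="float32")
    b = np.array([[True, False, True], [False, True, True]])

    values = a[b]                 # roughly what x contains: [1., 3., 5., 6.]
    indices = np.flatnonzero(b)   # flat positions of the selected elements: [0, 2, 4, 5]
    print(values, indices)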
| @@ -114,8 +114,10 @@ def test_replace_opr(): | |||||
| vara = graph.var_filter.name("a").as_unique() | vara = graph.var_filter.name("a").as_unique() | ||||
| varb = graph.var_filter.name("b").as_unique() | varb = graph.var_filter.name("b").as_unique() | ||||
| out1 = F.sub(vara, varb) | |||||
| out1 = F.mul(vara, varb) | |||||
| out1 = F.relu(out1) | out1 = F.relu(out1) | ||||
| out1 += 2 | |||||
| out1 *= 3 | |||||
| out1 = graph.add_dep_oprs(out1) | out1 = graph.add_dep_oprs(out1) | ||||
| orig_opr = graph.opr_filter.has_input(vara).as_unique() | orig_opr = graph.opr_filter.has_input(vara).as_unique() | ||||
| @@ -135,7 +137,7 @@ def test_replace_opr(): | |||||
| load_graph = GraphInference(modified_model1) | load_graph = GraphInference(modified_model1) | ||||
| out = load_graph.run(a, b) | out = load_graph.run(a, b) | ||||
| np.testing.assert_equal(out["o"], [0, 0]) | |||||
| np.testing.assert_equal(out["o"], [30, 60]) | |||||
| def test_splice_network(): | def test_splice_network(): | ||||
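The expectation change to [30, 60] in the hunk above tracks the new op chain: the replacement subgraph now computes relu(a * b), then adds 2 and multiplies by 3 before being spliced back into the network. The test's real inputs a and b are defined earlier in the file and are not shown in this hunk; the snippet below only re-traces the arithmetic in numpy, with made-up inputs [1, 2] and [8, 9] chosen so the result matches the expected output.

    import numpy as np

    a = np.array([1.0, 2.0], dtype="float32")   # illustrative stand-ins only
    b = np.array([8.0, 9.0], dtype="float32")

    out = np.maximum(a * b, 0.0)   # F.mul followed by F.relu
    out = out + 2                  # out1 += 2
    out = out * 3                  # out1 *= 3
    print(out)                     # [30. 60.]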
| @@ -82,6 +82,10 @@ std::string DTRCommand::to_string() const { | |||||
| return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind); | return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind); | ||||
| } | } | ||||
| std::string CreateNode::to_string() const { | |||||
| return "CreateNode"; | |||||
| } | |||||
| std::string GetName::to_string() const { | std::string GetName::to_string() const { | ||||
| return "GetName{}"; | return "GetName{}"; | ||||
| } | } | ||||
| @@ -94,5 +98,9 @@ std::string IsScalar::to_string() const { | |||||
| return "IsScalar"; | return "IsScalar"; | ||||
| } | } | ||||
| std::string GetVarVal::to_string() const { | |||||
| return "GetVarVal"; | |||||
| } | |||||
| } // namespace imperative | } // namespace imperative | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -157,5 +157,22 @@ public: | |||||
| std::string to_string() const override; | std::string to_string() const override; | ||||
| }; | }; | ||||
| class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> { | |||||
| public: | |||||
| std::string to_string() const override; | |||||
| }; | |||||
| class CreateNode final : public OperatorImpl<CreateNode> { | |||||
| private: | |||||
| cg::VarNode* m_node; | |||||
| public: | |||||
| CreateNode(cg::VarNode* node) : m_node(node) {} | |||||
| cg::VarNode* node() const { return m_node; } | |||||
| std::string to_string() const override; | |||||
| }; | |||||
| } // namespace imperative | } // namespace imperative | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -173,5 +173,24 @@ public: | |||||
| std::string to_string() const override; | std::string to_string() const override; | ||||
| }; | }; | ||||
| class NodeStorage { | |||||
| private: | |||||
| cg::VarNode* m_node; | |||||
| public: | |||||
| NodeStorage() = default; | |||||
| NodeStorage(VarNode* node) : m_node(node) {} | |||||
| VarNode* node() const { return m_node; } | |||||
| ComputingGraph* graph() const { return m_node->owner_graph(); } | |||||
| std::string to_string() const { return m_node->name(); } | |||||
| }; | |||||
| class NodeValue final : public PrimitiveValue<NodeValue, NodeStorage> { | |||||
| public: | |||||
| using PrimitiveValue::PrimitiveValue; | |||||
| std::string to_string() const override { return NodeStorage::to_string(); } | |||||
| }; | |||||
| } // namespace imperative | } // namespace imperative | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -39,13 +39,39 @@ private: | |||||
| ObjectType<SymbolValue> m_value_type{"SymbolValue"}; | ObjectType<SymbolValue> m_value_type{"SymbolValue"}; | ||||
| public: | public: | ||||
| SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} | |||||
| SymbolTransformation() {} | |||||
| ValueRefList apply_transformation( | ValueRefList apply_transformation( | ||||
| const Operator& op, Span<ValueRef> inputs) override { | const Operator& op, Span<ValueRef> inputs) override { | ||||
| ComputingGraph* cg = nullptr; | |||||
| if (auto* node_value = op.as<CreateNode>()) { | |||||
| return {m_value_type.make(node_value->node())}; | |||||
| } | |||||
| for (auto&& input : inputs) { | |||||
| if (auto* val = input.as(m_value_type)) { | |||||
| auto* node = val->node(); | |||||
| ComputingGraph* cur_cg = node->owner_graph(); | |||||
| if (cg == nullptr) { | |||||
| cg = cur_cg; | |||||
| } else { | |||||
| mgb_assert(cg == cur_cg, "input varnode graphs should be the same"); | ||||
| } | |||||
| } | |||||
| } | |||||
| if (!cg) { | |||||
| return imperative::apply(op, inputs); | |||||
| } | |||||
| if (auto* apply_op = op.as<ApplyOp>()) { | if (auto* apply_op = op.as<ApplyOp>()) { | ||||
| SmallVector<VarNode*> input_nodes; | SmallVector<VarNode*> input_nodes; | ||||
| for (auto&& input : inputs) { | for (auto&& input : inputs) { | ||||
| input_nodes.push_back(input.cast(m_value_type).node()); | |||||
| if (!input.is(m_value_type)) { | |||||
| auto* node = opr::ImmutableTensor::make( | |||||
| *cg, input.numpy()->as_nd(true), {}) | |||||
| .node(); | |||||
| input_nodes.push_back(node); | |||||
| } else { | |||||
| input_nodes.push_back(input.cast(m_value_type).node()); | |||||
| } | |||||
| } | } | ||||
| auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); | auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); | ||||
| ValueRefList outputs(output_nodes.size()); | ValueRefList outputs(output_nodes.size()); | ||||
| @@ -53,15 +79,9 @@ public: | |||||
| outputs[i] = m_value_type.make(output_nodes[i]); | outputs[i] = m_value_type.make(output_nodes[i]); | ||||
| } | } | ||||
| return outputs; | return outputs; | ||||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||||
| auto&& args = create_tensor->parse(inputs); | |||||
| mgb_assert( | |||||
| args.kind == CreateTensor::Const, | |||||
| "only const value is allowed here"); | |||||
| auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node(); | |||||
| return {m_value_type.make(node)}; | |||||
| } else if (auto* get_attr = op.as<GetAttr>()) { | } else if (auto* get_attr = op.as<GetAttr>()) { | ||||
| auto* node = inputs.item().cast(m_value_type).node(); | auto* node = inputs.item().cast(m_value_type).node(); | ||||
| auto* m_graph = node->owner_graph(); | |||||
| switch (get_attr->attr()) { | switch (get_attr->attr()) { | ||||
| case GetAttr::DType: | case GetAttr::DType: | ||||
| return {DTypeValue::make(node->dtype())}; | return {DTypeValue::make(node->dtype())}; | ||||
| @@ -105,6 +125,10 @@ public: | |||||
| MegBrainError, "Symbol: malformed GetAttr: %s", | MegBrainError, "Symbol: malformed GetAttr: %s", | ||||
| op.to_string().c_str()); | op.to_string().c_str()); | ||||
| } | } | ||||
| } else if (auto* get_attr = op.as<GetVarVal>()) { | |||||
| cg::VarNode* node = inputs.item().cast(m_value_type).node(); | |||||
| NodeStorage inp_var = NodeStorage(node); | |||||
| return {NodeValue::make(inp_var)}; | |||||
| } else { | } else { | ||||
| return op.fallback(inputs); | return op.fallback(inputs); | ||||
| } | } | ||||
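Taken together, the rewritten apply_transformation above no longer depends on a ComputingGraph handed to the constructor: it scans the inputs, takes the owner graph of the first SymbolValue it finds, asserts that every other node input lives on the same graph, falls back to the default apply when no node input is present, and lifts any remaining plain-tensor input onto the inferred graph as an ImmutableTensor before calling apply_on_var_node. A schematic Python sketch of that control flow, using hypothetical NodeInput/lift_const/default_apply stand-ins rather than the real MegEngine types:

    from dataclasses import dataclass

    @dataclass
    class NodeInput:           # stands in for a SymbolValue wrapping a VarNode
        graph: str             # owner graph, reduced to a plain label
        value: float

    def lift_const(graph, value):
        # plays the role of opr::ImmutableTensor::make(*cg, ...): put a plain
        # value onto the graph that the node inputs belong to
        return NodeInput(graph, value)

    def default_apply(op, inputs):      # fallback when no node input exists
        return [op(*(getattr(i, "value", i) for i in inputs))]

    def apply_transformation(op, inputs):
        cg = None
        for inp in inputs:              # infer the graph from node inputs
            if isinstance(inp, NodeInput):
                if cg is None:
                    cg = inp.graph
                else:
                    assert cg == inp.graph, "input varnodes must share one graph"
        if cg is None:
            return default_apply(op, inputs)
        lifted = [i if isinstance(i, NodeInput) else lift_const(cg, i) for i in inputs]
        return [NodeInput(cg, op(*(i.value for i in lifted)))]

    print(apply_transformation(lambda x, y: x * y, [NodeInput("g0", 3.0), 2.0]))
    # -> [NodeInput(graph='g0', value=6.0)]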
| @@ -33,6 +33,7 @@ class ShapeValue; | |||||
| class DTypeValue; | class DTypeValue; | ||||
| class CompNodeValue; | class CompNodeValue; | ||||
| class StringValue; | class StringValue; | ||||
| class NodeValue; | |||||
| class Operator; | class Operator; | ||||