GitOrigin-RevId: 16da6d1491
tags/v1.10.0
| @@ -7,9 +7,7 @@ from typing import Union | |||
| import numpy as np | |||
| from .. import _config | |||
| from .._imperative_rt.common import CompNode | |||
| from .._imperative_rt.core2 import ( | |||
| SymbolVar, | |||
| Tensor, | |||
| apply, | |||
| astype_cpp, | |||
| @@ -17,9 +15,11 @@ from .._imperative_rt.core2 import ( | |||
| broadcast_cpp, | |||
| getitem_cpp, | |||
| matmul_cpp, | |||
| reshape_cpp, | |||
| setitem_cpp, | |||
| squeeze_cpp, | |||
| transpose_cpp, | |||
| ) | |||
| from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | |||
| from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp | |||
| from ..ops import builtin | |||
| from . import amp | |||
| from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph | |||
| @@ -189,9 +189,7 @@ def _todo(*_): | |||
| def _expand_args(args): | |||
| if len(args) == 1: | |||
| if isinstance( | |||
| args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray), | |||
| ): | |||
| if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),): | |||
| args = args[0] | |||
| return args | |||
| @@ -8,7 +8,6 @@ import numpy as np | |||
| from .._imperative_rt import make_const | |||
| from .._imperative_rt.core2 import ( | |||
| Const, | |||
| SymbolVar, | |||
| Tensor, | |||
| _get_convert_inputs, | |||
| _set_convert_inputs, | |||
| @@ -77,7 +76,7 @@ def result_type(*args): | |||
| def isscalar(x): | |||
| if isinstance(x, (Tensor, SymbolVar)): | |||
| if isinstance(x, Tensor): | |||
| return x._isscalar() | |||
| return np.isscalar(x) | |||
| @@ -283,7 +282,7 @@ def interpret_subgraph(func, dtype, device): | |||
| return results | |||
| def apply_const(value, dtype=dtype, device=device): | |||
| return Const(value, dtype, device, None) | |||
| return Const(value, dtype, device) | |||
| outputs, outputs_has_grad = func(args, apply_expr, apply_const) | |||
| outputs = [ | |||
| @@ -2,7 +2,7 @@ | |||
| # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Elemwise | |||
| from ..core.tensor.array_method import _elwise | |||
| @@ -538,7 +538,7 @@ def topk( | |||
| op = builtin.TopK(mode=mode) | |||
| if not isinstance(k, Tensor): | |||
| k = Const(k, "int32", inp.device, None) | |||
| k = Const(k, "int32", inp.device) | |||
| if len(inp.shape) == 1: | |||
| if kth_only: | |||
| @@ -1222,7 +1222,7 @@ def batch_norm( | |||
| raise ValueError("Invalid param_dim {}".format(param_dim)) | |||
| if x is None: | |||
| x = Const(value, inp.dtype, inp.device, None) | |||
| x = Const(value, inp.dtype, inp.device) | |||
| shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | |||
| (result,) = apply(builtin.Broadcast(), x, shape) | |||
| return result | |||
| @@ -1446,7 +1446,7 @@ def sync_batch_norm( | |||
| def _make_full_if_none(x, value): | |||
| if x is None: | |||
| x = Const(value, inp.dtype, _device, None) | |||
| x = Const(value, inp.dtype, _device) | |||
| (result,) = apply(builtin.Broadcast(), x, reduce_shape) | |||
| return result | |||
| elif x.ndim == 1: | |||
| @@ -7,7 +7,6 @@ import numpy as np | |||
| from ..core._imperative_rt import CompNode | |||
| from ..core._imperative_rt.core2 import ( | |||
| Const, | |||
| SymbolVar, | |||
| apply, | |||
| broadcast_cpp, | |||
| dtype_promotion, | |||
| @@ -151,7 +150,7 @@ def full( | |||
| shape = (shape,) | |||
| if device is None: | |||
| device = get_default_device() | |||
| x = Const(value, dtype, device, None) | |||
| x = Const(value, dtype, device) | |||
| if type(shape) in (list, tuple) and len(shape) == 0: | |||
| return x | |||
| return broadcast_to(x, shape) | |||
| @@ -216,7 +215,7 @@ def zeros( | |||
| return full(shape, 0.0, dtype=dtype, device=device) | |||
| def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
| def zeros_like(inp: Tensor) -> Tensor: | |||
| r"""Returns a tensor filled with zeros with the same shape and data type as input tensor. | |||
| Args: | |||
| @@ -235,7 +234,7 @@ def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
| return full_like(inp, 0.0) | |||
| def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
| def ones_like(inp: Tensor) -> Tensor: | |||
| r"""Returns a tensor filled with ones with the same shape and data type as input tensor. | |||
| Args: | |||
| @@ -253,9 +252,7 @@ def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
| return full_like(inp, 1.0) | |||
| def full_like( | |||
| inp: Union[Tensor, SymbolVar], value: Union[int, float] | |||
| ) -> Union[Tensor, SymbolVar]: | |||
| def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||
| r"""Returns a tensor filled with given value with the same shape as input tensor. | |||
| Args: | |||
| @@ -272,7 +269,7 @@ def full_like( | |||
| Tensor([[2 2 2] | |||
| [2 2 2]], dtype=int32, device=xpux:0) | |||
| """ | |||
| x = Const(value, inp.dtype, inp.device, inp) | |||
| x = Const(value, inp.dtype, inp.device) | |||
| if inp.ndim == 0: | |||
| return x | |||
| return broadcast_to(x, inp.shape) | |||
| @@ -668,9 +665,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor: | |||
| >>> print(v.numpy(), index.numpy()) | |||
| [1. 4.] [0 3] | |||
| """ | |||
| if not isinstance(x, (Tensor, SymbolVar)): | |||
| if not isinstance(x, Tensor): | |||
| raise TypeError("input must be a tensor") | |||
| if not isinstance(mask, (Tensor, SymbolVar)): | |||
| if not isinstance(mask, Tensor): | |||
| raise TypeError("mask must be a tensor") | |||
| if mask.dtype != np.bool_: | |||
| raise ValueError("mask must be bool") | |||
| @@ -843,15 +840,11 @@ def linspace( | |||
| if not (cur_device is None or device == cur_device): | |||
| raise ("ambiguous device for linspace opr") | |||
| is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num]) | |||
| if any(is_symbolvar) and not all(is_symbolvar): | |||
| raise TypeError("start, stop and num should all be VarNode or none of them") | |||
| if not isinstance(start, (Tensor, SymbolVar)): | |||
| if not isinstance(start, Tensor): | |||
| start = Tensor(start, device=device) | |||
| if not isinstance(stop, (Tensor, SymbolVar)): | |||
| if not isinstance(stop, Tensor): | |||
| stop = Tensor(stop, device=device) | |||
| if not isinstance(num, (Tensor, SymbolVar)): | |||
| if not isinstance(num, Tensor): | |||
| num = Tensor(num, device=device) | |||
| op = builtin.Linspace(comp_node=device) | |||
| @@ -901,9 +894,12 @@ def arange( | |||
| if stop is None: | |||
| start, stop = 0, start | |||
| start = Tensor(start, dtype="float32") | |||
| stop = Tensor(stop, dtype="float32") | |||
| step = Tensor(step, dtype="float32") | |||
| if not isinstance(start, Tensor): | |||
| start = Tensor(start, dtype="float32") | |||
| if not isinstance(stop, Tensor): | |||
| stop = Tensor(stop, dtype="float32") | |||
| if not isinstance(step, Tensor): | |||
| step = Tensor(step, dtype="float32") | |||
| num = ceil((stop - start) / step) | |||
| stop = start + step * (num - 1) | |||
| @@ -7,11 +7,11 @@ small_tensor_cache = {} | |||
| def _get_scalar_tensor_with_value(value, dtype=None, device=None): | |||
| global small_tensor_cache | |||
| if is_tracing(): | |||
| ret = Const(value, dtype, device, None) | |||
| ret = Const(value, dtype, device) | |||
| else: | |||
| cache_key = (value, dtype, device) | |||
| if cache_key not in small_tensor_cache: | |||
| ret = Const(value, dtype, device, None) | |||
| ret = Const(value, dtype, device) | |||
| small_tensor_cache[cache_key] = ret | |||
| else: | |||
| ret = small_tensor_cache[cache_key] | |||
| @@ -154,6 +154,8 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
| @name.setter | |||
| def name(self, name): | |||
| self._custom_name = name | |||
| if name == None: | |||
| name = "" | |||
| self._name = self._prefix + "." + name if self._prefix else name | |||
| self._set_name(self._name) | |||
| @@ -756,7 +756,7 @@ class Constant(Expr): | |||
| def interpret(self, *inputs): | |||
| if isinstance(self.value, RawTensor): | |||
| return (Const(self.value.numpy(), None, None, None),) | |||
| return (Const(self.value.numpy(), None, None),) | |||
| return (self.value,) | |||
| def __repr__(self): | |||
| @@ -395,7 +395,7 @@ class Network: | |||
| for ind, var in enumerate(opr.outputs): | |||
| var.owner = repl_dict[opr] | |||
| var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | |||
| var.var = repl_dict[opr].outputs[ind].var | |||
| var._reset_var(repl_dict[opr].outputs[ind].var) | |||
| repl_dict[opr].outputs = opr.outputs | |||
| self._compile() | |||
| @@ -6,11 +6,11 @@ from typing import Sequence | |||
| import numpy as np | |||
| from ..core import _imperative_rt as rt | |||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||
| from ..core._imperative_rt.core2 import apply, set_py_varnode_type | |||
| from ..core._trace_option import use_symbolic_shape | |||
| from ..core._wrap import Device | |||
| from ..core.ops import builtin | |||
| from ..core.tensor.array_method import ArrayMethodMixin | |||
| from ..tensor import Tensor | |||
| from .comp_graph_tools import replace_vars | |||
| from .module_stats import ( | |||
| preprocess_receptive_field, | |||
| @@ -23,26 +23,72 @@ class NetworkNode: | |||
| pass | |||
| class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)): | |||
| pass | |||
| class VarNode(NetworkNode, Tensor): | |||
| _users = None | |||
| _owner = None | |||
| _name = None | |||
| _id = None | |||
| def __new__(cls, var, *, owner_opr=None, name=None): | |||
| obj = Tensor.__new__(cls, var) | |||
| return obj | |||
| class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
| def __init__(self, var=None, *, owner_opr=None, name=None): | |||
| SymbolVar.__init__(self, var) | |||
| self.users = [] # List[OpNode] | |||
| self.owner = owner_opr | |||
| def __init__(self, var, *, owner_opr=None, name=None): | |||
| self._owner = owner_opr | |||
| self.name = name | |||
| self.id = id(self) | |||
| @classmethod | |||
| def load(cls, sym_var, owner_opr): | |||
| obj = cls() | |||
| obj = cls(sym_var) | |||
| obj.var = sym_var # mgb varnode | |||
| obj.name = sym_var.name | |||
| obj.owner = owner_opr | |||
| return obj | |||
| @property | |||
| def users(self): | |||
| if self._users is None: | |||
| self._users = [] | |||
| return self._users | |||
| @property | |||
| def owner(self): | |||
| return self._owner | |||
| @owner.setter | |||
| def owner(self, owner): | |||
| self._owner = owner | |||
| @property | |||
| def id(self): | |||
| if self._id is None: | |||
| self._id = id(self) | |||
| return self._id | |||
| @property | |||
| def var(self): | |||
| return super().var() | |||
| @var.setter | |||
| def var(self, var): | |||
| self._reset(var) | |||
| def _reset(self, other): | |||
| if not isinstance(other, Tensor): | |||
| other = VarNode(other) | |||
| super()._reset(other) | |||
| self.owner = None | |||
| def _reset_var(self, var): | |||
| origin_owner = self.owner | |||
| self.var = var | |||
| self.var.name = self.name | |||
| self.owner = origin_owner | |||
| @property | |||
| def graph(self): | |||
| return super().graph() | |||
| def _get_var_shape(self, axis=None): | |||
| opdef = ( | |||
| builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis) | |||
| @@ -77,14 +123,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
| return rst | |||
| return self._get_var_shape() if self.var else None | |||
| @property | |||
| def dtype(self): | |||
| return self.var.dtype if self.var else None | |||
| @property | |||
| def ndim(self): | |||
| return super().ndim | |||
| def __bool__(self): | |||
| return False | |||
| @@ -92,27 +130,11 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
| __int__ = None | |||
| __float__ = None | |||
| __complex__ = None | |||
| __repr__ = lambda self: "VarNode:" + self.name | |||
| def __hash__(self): | |||
| return id(self) | |||
| def numpy(self): | |||
| return super().numpy() | |||
| def _reset(self, other): | |||
| if not isinstance(other, VarNode): | |||
| assert self.graph, "VarNode _reset must have graph" | |||
| node = ImmutableTensor(other, graph=self.graph) | |||
| node.compile(self.graph) | |||
| other = node.outputs[0] | |||
| if self.owner is not None: | |||
| idx = self.owner.outputs.index(self) | |||
| self.owner.outputs[idx] = VarNode( | |||
| self.var, owner_opr=self.owner, name=self.var.name | |||
| ) | |||
| self.var = other.var | |||
| self.owner = None | |||
| def set_owner_opr(self, owner_opr): | |||
| self.owner = owner_opr | |||
| @@ -158,8 +180,7 @@ class OpNode(NetworkNode): | |||
| assert len(outputs) == len(self.outputs) | |||
| self._opr = outputs[0].owner | |||
| for i in range(len(self.outputs)): | |||
| self.outputs[i].var = outputs[i] | |||
| self.outputs[i].var.name = self.outputs[i].name | |||
| self.outputs[i]._reset_var(outputs[i]) | |||
| assert self.outputs[i].owner is self | |||
| def add_inp_var(self, x): | |||
| @@ -214,8 +235,9 @@ class Host2DeviceCopy(OpNode): | |||
| outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
| self._opr = outputs.owner | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
| self.outputs[0].var = outputs | |||
| self.outputs.append(VarNode(outputs, owner_opr=self, name=self.name)) | |||
| else: | |||
| self.outputs[0]._reset_var(outputs) | |||
| assert self.outputs[0].owner is self | |||
| @@ -262,8 +284,9 @@ class ConstOpBase(OpNode): | |||
| data = data.astype(np.int32) | |||
| varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name) | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
| self.outputs[0].var = varnode | |||
| self.outputs.append(VarNode(varnode, owner_opr=self, name=self.name)) | |||
| else: | |||
| self.outputs[0]._reset_var(varnode) | |||
| self._opr = varnode.owner | |||
| @classmethod | |||
| @@ -313,7 +336,7 @@ class ReadOnlyOpNode(OpNode): | |||
| if bool(repl_dict): | |||
| out_vars = replace_vars(self._opr.outputs, repl_dict) | |||
| for ind, o in enumerate(self.outputs): | |||
| o.var = out_vars[ind] | |||
| o._reset_var(out_vars[ind]) | |||
| class Elemwise(OpNode): | |||
| @@ -785,3 +808,6 @@ class AssertEqual(OpNode): | |||
| class CvtColorForward(OpNode): | |||
| type = "CvtColor" | |||
| opdef = builtin.CvtColor | |||
| set_py_varnode_type(VarNode) | |||
| @@ -114,6 +114,8 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) { | |||
| } | |||
| } | |||
| py::object Py_Varnode = py::none(); | |||
| void init_graph_rt(py::module m) { | |||
| static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{ | |||
| std::make_unique<mgb::OprFootprint>()}; | |||
| @@ -124,40 +126,44 @@ void init_graph_rt(py::module m) { | |||
| def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous"); | |||
| py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode") | |||
| .def_property_readonly( | |||
| "owner", [](cg::VarNode* v) { return v->owner_opr(); }) | |||
| .def_property_readonly( | |||
| "graph", [](cg::VarNode* v) { return v->owner_graph(); }) | |||
| .def_property( | |||
| "name", py::overload_cast<>(&VarNode::name, py::const_), | |||
| py::overload_cast<std::string>(&VarNode::name)) | |||
| .def_property_readonly("dtype", [](cg::VarNode* v) { return v->dtype(); }) | |||
| .def_property_readonly( | |||
| "comp_node", [](cg::VarNode* v) { return v->comp_node(); }) | |||
| .def_property_readonly( | |||
| "shape", | |||
| [](cg::VarNode* v) -> const TensorShape* { | |||
| auto&& mgr = v->owner_graph()->static_infer_manager(); | |||
| return mgr.infer_shape_fallible(v); | |||
| }) | |||
| .def_property_readonly( | |||
| "value", | |||
| [](cg::VarNode* v) -> py::object { | |||
| auto&& mgr = v->owner_graph()->static_infer_manager(); | |||
| auto&& type = mgr.get_infer_type(v); | |||
| using InferType = cg::static_infer::InferType; | |||
| if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { | |||
| return py::none(); | |||
| } | |||
| auto* val = mgr.infer_value_fallible(v); | |||
| if (!val) { | |||
| return py::none(); | |||
| } | |||
| return py::cast(*val).attr("numpy")(); | |||
| }) | |||
| .def_property_readonly("id", [](cg::VarNode* v) { return (v->id()); }) | |||
| .def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); }); | |||
| Py_Varnode = | |||
| py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode") | |||
| .def_property_readonly( | |||
| "owner", [](cg::VarNode* v) { return v->owner_opr(); }) | |||
| .def_property_readonly( | |||
| "graph", [](cg::VarNode* v) { return v->owner_graph(); }) | |||
| .def_property( | |||
| "name", py::overload_cast<>(&VarNode::name, py::const_), | |||
| py::overload_cast<std::string>(&VarNode::name)) | |||
| .def_property_readonly( | |||
| "dtype", [](cg::VarNode* v) { return v->dtype(); }) | |||
| .def_property_readonly( | |||
| "comp_node", [](cg::VarNode* v) { return v->comp_node(); }) | |||
| .def_property_readonly( | |||
| "shape", | |||
| [](cg::VarNode* v) -> const TensorShape* { | |||
| auto&& mgr = v->owner_graph()->static_infer_manager(); | |||
| return mgr.infer_shape_fallible(v); | |||
| }) | |||
| .def_property_readonly( | |||
| "value", | |||
| [](cg::VarNode* v) -> py::object { | |||
| auto&& mgr = v->owner_graph()->static_infer_manager(); | |||
| auto&& type = mgr.get_infer_type(v); | |||
| using InferType = cg::static_infer::InferType; | |||
| if (!(type.value & | |||
| (InferType::CONST | InferType::RT_STATIC))) { | |||
| return py::none(); | |||
| } | |||
| auto* val = mgr.infer_value_fallible(v); | |||
| if (!val) { | |||
| return py::none(); | |||
| } | |||
| return py::cast(*val).attr("numpy")(); | |||
| }) | |||
| .def_property_readonly( | |||
| "id", [](cg::VarNode* v) { return (v->id()); }) | |||
| .def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); }); | |||
| py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>( | |||
| m, "OperatorNode") | |||
| @@ -8,6 +8,9 @@ | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/plugin/opr_footprint.h" | |||
| namespace py = pybind11; | |||
| extern py::object Py_Varnode; | |||
| template <typename T> | |||
| class GraphNodePtr { | |||
| std::shared_ptr<mgb::cg::ComputingGraph> m_graph; | |||
| @@ -48,58 +48,11 @@ namespace mgb::imperative::python { | |||
| namespace { | |||
| WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map; | |||
| struct SymbolVarContext { | |||
| TransformationContext context; | |||
| std::shared_ptr<SymbolTransformation> symbol_tsf; | |||
| std::shared_ptr<ScalarTransformation> scalar_tsf; | |||
| std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf; | |||
| std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf; | |||
| SymbolVarContext(cg::ComputingGraph* graph) { | |||
| symbol_tsf = std::make_shared<SymbolTransformation>(graph); | |||
| scalar_tsf = std::make_shared<ScalarTransformation>(); | |||
| dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>(); | |||
| dim_expansion_tsf = std::make_shared<DimExpansionTransformation>(); | |||
| Transformation::swap_context(context); | |||
| } | |||
| void init() { | |||
| symbol_tsf->register_at(Transformation::top()); | |||
| scalar_tsf->register_at(Transformation::top()); | |||
| dtype_promote_tsf->register_at(Transformation::top()); | |||
| dim_expansion_tsf->register_at(Transformation::top()); | |||
| } | |||
| ValueRef symvar2val(py::handle py_symbol_var) { | |||
| auto* symbol_var = py_symbol_var.cast<PySymbolVar*>(); | |||
| ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node); | |||
| if (symbol_var->is_scalar) { | |||
| value = scalar_tsf->value_type().make(value); | |||
| } | |||
| return value; | |||
| } | |||
| py::object val2symvar(py::handle typeobj, ValueRef value) { | |||
| bool is_scalar = false; | |||
| if (auto* scalar_value = value.as(scalar_tsf->value_type())) { | |||
| value = scalar_value->value(); | |||
| is_scalar = true; | |||
| } | |||
| auto* node = value.cast(symbol_tsf->value_type()).node(); | |||
| auto py_symbol_var = | |||
| typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic)); | |||
| py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar; | |||
| return py_symbol_var; | |||
| } | |||
| ~SymbolVarContext() { Transformation::swap_context(context); } | |||
| }; | |||
| } // namespace | |||
| interpreter::Interpreter::Channel* interpreter_for_py = nullptr; | |||
| PyTypeObject* py_tensor_type = nullptr; | |||
| PyTypeObject* py_varnode_type = nullptr; | |||
| pybind11::handle py_device_type = nullptr; | |||
| PyObject* cpp_use_symbolic_shape; | |||
| @@ -136,22 +89,6 @@ PyObject* py_apply( | |||
| auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>(); | |||
| SmallVector<ValueRef, 8> tensors(nargs); | |||
| SmallVector<bool, 8> is_symbol_var(nargs, false); | |||
| ComputingGraph* cg = nullptr; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| if ((!TensorWrapper::try_cast(args[i])) && | |||
| py::isinstance<PySymbolVar>(py::handle(args[i]))) { | |||
| is_symbol_var[i] = true; | |||
| ComputingGraph* cur_cg = | |||
| py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph(); | |||
| if (cg == nullptr) { | |||
| cg = cur_cg; | |||
| } else { | |||
| mgb_assert(cg == cur_cg); | |||
| } | |||
| } | |||
| } | |||
| mgb::CompNode target_cn; | |||
| mgb::DType target_dtype; | |||
| @@ -174,35 +111,11 @@ PyObject* py_apply( | |||
| } | |||
| }; | |||
| if (cg != nullptr) { | |||
| // swap to a special context to reuse scalar handle | |||
| size_t symbol_var_idx = 8; | |||
| SymbolVarContext context(cg); | |||
| context.init(); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| if (is_symbol_var[i]) { | |||
| symbol_var_idx = i; | |||
| tensors[i] = context.symvar2val(args[i]); | |||
| } else if ( | |||
| DTypePromoteCfg::convert_input_enabled && | |||
| op->same_type<Elemwise>()) { | |||
| tensors[i] = convert_pyinput_to_tensor(i); | |||
| } else { | |||
| PyErr_SetString( | |||
| PyExc_TypeError, "py_apply expects tensor as inputs"); | |||
| return nullptr; | |||
| } | |||
| } | |||
| auto outputs = imperative::apply(*op, tensors); | |||
| auto ret = pybind11::tuple(outputs.size()); | |||
| auto typeobj = py::handle(args[symbol_var_idx]).get_type(); | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| ret[i] = context.val2symvar(typeobj, outputs[i]); | |||
| } | |||
| return ret.release().ptr(); | |||
| } | |||
| bool is_varnode_apply = false; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| if (PyObject_TypeCheck(args[i], py_varnode_type)) { | |||
| is_varnode_apply = true; | |||
| } | |||
| if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | |||
| tensors[i] = tw->m_tensor->data(); | |||
| } else if ( | |||
| @@ -218,8 +131,9 @@ PyObject* py_apply( | |||
| auto outputs = [&] { return imperative::apply(*op, tensors); }(); | |||
| size_t nout = outputs.size(); | |||
| auto ret = py::tuple(nout); | |||
| PyTypeObject* py_type = is_varnode_apply ? py_varnode_type : py_tensor_type; | |||
| for (size_t i = 0; i < nout; ++i) { | |||
| ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i])); | |||
| ret[i] = TensorWrapper::make(py_type, std::move(outputs[i])); | |||
| } | |||
| return ret.release().ptr(); | |||
| } | |||
| @@ -622,9 +536,17 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
| CreateTensor::Kind kind = is_const ? CreateTensor::Const | |||
| : no_cache ? CreateTensor::Unique | |||
| : CreateTensor::Common; | |||
| auto&& hval = pyobj2hval(data, cn, dtype); | |||
| auto val = imperative::apply( | |||
| CreateTensor(kind, cn, hval.dtype, hval.shape), hval.storage)[0]; | |||
| ValueRef val; | |||
| if (py::isinstance(data, Py_Varnode)) { | |||
| cg::VarNode* m_node = py::handle(data).cast<cg::VarNode*>(); | |||
| val = imperative::apply( | |||
| CreateNode(m_node), Span<ValueRef>(nullptr, nullptr))[0]; | |||
| } else { | |||
| auto&& hval = pyobj2hval(data, cn, dtype); | |||
| val = imperative::apply( | |||
| CreateTensor(kind, cn, hval.dtype, hval.shape), | |||
| hval.storage)[0]; | |||
| } | |||
| m_tensor.emplace(val); | |||
| } | |||
| @@ -734,6 +656,20 @@ PyObject* TensorWrapper::isscalar() { | |||
| } | |||
| } | |||
| PyObject* TensorWrapper::_var() { | |||
| TypedValueRef<NodeValue> value = | |||
| imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>(); | |||
| auto* node = value->node(); | |||
| return py::cast(node).release().ptr(); | |||
| } | |||
| PyObject* TensorWrapper::_graph() { | |||
| TypedValueRef<NodeValue> value = | |||
| imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>(); | |||
| auto* graph = value->graph(); | |||
| return py::cast(graph).release().ptr(); | |||
| } | |||
| struct TensorWeakRef { | |||
| ValueWeakRef data; | |||
| @@ -807,6 +743,10 @@ void init_tensor(py::module m) { | |||
| .register_at<Segment::Scalar>( | |||
| std::make_shared<ScalarTransformation>()) | |||
| .release()); | |||
| MGB_MARK_USED_VAR(transformations | |||
| .register_at<Segment::Symbol>( | |||
| std::make_shared<SymbolTransformation>()) | |||
| .release()); | |||
| MGB_MARK_USED_VAR(transformations | |||
| .register_at<Segment::DTypePromote>( | |||
| std::make_shared<DTypePromoteTransformation>()) | |||
| @@ -863,6 +803,8 @@ void init_tensor(py::module m) { | |||
| .def<&TensorWrapper::_detail>("_detail") | |||
| .def<&TensorWrapper::_set_name>("_set_name") | |||
| .def<&TensorWrapper::_watch>("_watch") | |||
| .def<&TensorWrapper::_var>("var") | |||
| .def<&TensorWrapper::_graph>("graph") | |||
| .def_getset< | |||
| &TensorWrapper::module_trace_info, | |||
| &TensorWrapper::set_module_trace_info>("_NodeMixin__node") | |||
| @@ -875,43 +817,6 @@ void init_tensor(py::module m) { | |||
| .def(py::init<const TensorWrapper&>()) | |||
| .def("__call__", &TensorWeakRef::operator()); | |||
| py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar") | |||
| .def_property_readonly( | |||
| "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); }) | |||
| .def_property( | |||
| "var", [](PySymbolVar* v) { return v->m_node; }, | |||
| [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; }) | |||
| .def_property_readonly( | |||
| "device", [](PySymbolVar* v) { return v->m_node->comp_node(); }) | |||
| .def_property_readonly( | |||
| "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); }) | |||
| .def_property_readonly( | |||
| "shape", | |||
| [](PySymbolVar* v) -> const TensorShape* { | |||
| auto&& mgr = v->m_node->owner_graph()->static_infer_manager(); | |||
| return mgr.infer_shape_fallible(v->m_node); | |||
| }) | |||
| .def("numpy", | |||
| [](PySymbolVar* v) { | |||
| auto&& mgr = v->m_node->owner_graph()->static_infer_manager(); | |||
| auto&& type = mgr.get_infer_type(v->m_node); | |||
| using InferType = cg::static_infer::InferType; | |||
| if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { | |||
| throw py::value_error("value invalid!"); | |||
| } | |||
| auto* val = mgr.infer_value_fallible(v->m_node); | |||
| if (!val) { | |||
| throw py::value_error("value invalid!"); | |||
| } | |||
| auto np_val = py::cast(*val).attr("numpy")(); | |||
| return np_val; | |||
| }) | |||
| .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | |||
| .def(py::init([](cg::VarNode* node) { | |||
| return std::make_shared<PySymbolVar>(node); | |||
| }), | |||
| py::arg() = nullptr); | |||
| static PyMethodDef method_defs[] = { | |||
| MGE_PY_INTERFACE(apply, py_apply), | |||
| MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), | |||
| @@ -1027,6 +932,10 @@ void init_tensor(py::module m) { | |||
| py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr()); | |||
| }); | |||
| m.def("set_py_varnode_type", [](py::object type_obj) { | |||
| py_varnode_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr()); | |||
| }); | |||
| m.def("set_py_device_type", | |||
| [](py::object type_obj) { py_device_type = type_obj.inc_ref(); }); | |||
| @@ -1217,31 +1126,6 @@ void init_tensor(py::module m) { | |||
| } | |||
| }); | |||
| m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object { | |||
| auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) { | |||
| auto make_scalar_shape = [&](CompNode device) { | |||
| return imperative::apply( | |||
| CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}), | |||
| HostStorage::make(device))[0]; | |||
| }; | |||
| return imperative::apply(op, input, make_scalar_shape(*input.device()))[0]; | |||
| }; | |||
| if (py::isinstance<PySymbolVar>(tensor)) { | |||
| auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph(); | |||
| SymbolVarContext context(graph); | |||
| context.init(); | |||
| auto output = reduce_to_scalar( | |||
| *op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor)); | |||
| auto typeobj = tensor.get_type(); | |||
| return context.val2symvar(typeobj, output); | |||
| } else { | |||
| auto* tw = TensorWrapper::try_cast(tensor.ptr()); | |||
| auto output = reduce_to_scalar( | |||
| *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data()); | |||
| return TensorWrapper::make(py_tensor_type, output); | |||
| } | |||
| }); | |||
| m.def("name_tensor", [](std::string name, py::object tensor) { | |||
| auto* tw = TensorWrapper::try_cast(tensor.ptr()); | |||
| auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; | |||
| @@ -10,6 +10,8 @@ | |||
| #include "./pyext17.h" | |||
| #include "megbrain/imperative/dispatch.h" | |||
| #include "megbrain/imperative/transformations/scalar.h" | |||
| #include "megbrain/imperative/transformations/symbol.h" | |||
| #include "megbrain/imperative/utils/span.h" | |||
| namespace mgb::imperative::python { | |||
| @@ -27,6 +29,7 @@ namespace mgb::imperative::python { | |||
| extern interpreter::Interpreter::Channel* interpreter_for_py; | |||
| extern PyTypeObject* py_tensor_type; | |||
| extern PyTypeObject* py_varnode_type; | |||
| extern pybind11::handle py_device_type; | |||
| extern PyObject* cpp_use_symbolic_shape; | |||
| extern PyObject* cpp_astensor1d; | |||
| @@ -126,16 +129,11 @@ public: | |||
| void set_module_trace_info(PyObject*); | |||
| void _set_name(PyObject*); | |||
| PyObject* _detail(); | |||
| PyObject* _var(); | |||
| PyObject* _graph(); | |||
| void _watch(); | |||
| }; | |||
| struct PySymbolVar { | |||
| cg::VarNode* m_node = nullptr; | |||
| bool is_scalar = false; | |||
| PySymbolVar() = default; | |||
| PySymbolVar(VarNode* m) : m_node(m) {} | |||
| }; | |||
| PyObject* py_apply( | |||
| PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */); | |||
| @@ -146,15 +146,6 @@ PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) { | |||
| continue; | |||
| } | |||
| if (py::isinstance<PySymbolVar>(py::handle(handle))) { | |||
| auto var = py::handle(handle).cast<PySymbolVar*>(); | |||
| mgb::DType type = var->m_node->dtype(); | |||
| auto&& descr = npy::dtype_mgb2np_descr(type); | |||
| Py_INCREF(descr.get()); | |||
| tensors.emplace_back(descr.get()); | |||
| continue; | |||
| } | |||
| PyArray_Descr* descr = scalar2dtype(handle); | |||
| if (descr) { | |||
| scalars.emplace_back(descr); | |||
| @@ -204,17 +195,12 @@ CompNode _get_device(PyObject* const* args, size_t nargs) { | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; | |||
| TensorWrapper* tw = TensorWrapper::try_cast(handle); | |||
| bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle)); | |||
| if (tw || is_symvar) { | |||
| if (tw) { | |||
| if (!valid) { | |||
| cn = tw ? tw->m_tensor->comp_node() | |||
| : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node(); | |||
| cn = tw->m_tensor->comp_node(); | |||
| valid = true; | |||
| } else { | |||
| CompNode cn1 = tw ? tw->m_tensor->comp_node() | |||
| : py::handle(handle) | |||
| .cast<PySymbolVar*>() | |||
| ->m_node->comp_node(); | |||
| CompNode cn1 = tw->m_tensor->comp_node(); | |||
| if (cn1 != cn) { | |||
| throw py::value_error(ssprintf( | |||
| "ambiguous device: %s (from %s) vs %s (from %s)", | |||
| @@ -258,10 +244,6 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) { | |||
| } | |||
| bool is_scalar(PyObject* tensor) { | |||
| if (py::isinstance<PySymbolVar>(py::handle(tensor))) { | |||
| auto var = py::handle(tensor).cast<PySymbolVar*>(); | |||
| return var->is_scalar; | |||
| } | |||
| auto* tw = TensorWrapper::try_cast(tensor); | |||
| if (tw) { | |||
| return tw->m_tensor->is_scalar(); | |||
| @@ -319,8 +301,7 @@ py::object device2obj(py::handle device, bool mapping = false) { | |||
| } | |||
| } | |||
| py::object _Const( | |||
| py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) { | |||
| py::object _Const(py::handle value, py::handle dtype, py::handle device) { | |||
| py::object val = py::reinterpret_borrow<py::object>(value); | |||
| if (PyArray_Check(value.ptr())) { | |||
| py::tuple strides = | |||
| @@ -338,32 +319,6 @@ py::object _Const( | |||
| val = val.attr("reshape")(orig_shp); | |||
| } | |||
| } | |||
| py::object ref; | |||
| if (py::isinstance<py::tuple>(ref_hdl)) { | |||
| py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl); | |||
| if (tup.size()) { | |||
| ref = tup[0]; | |||
| } else { | |||
| ref = py::none(); | |||
| } | |||
| } else { | |||
| ref = py::reinterpret_borrow<py::object>(ref_hdl); | |||
| } | |||
| if (py::isinstance<PySymbolVar>(ref)) { | |||
| auto ref_var = ref.cast<PySymbolVar*>(); | |||
| auto* graph = ref_var->m_node->owner_graph(); | |||
| CompNode cn; | |||
| if (device.ptr() == Py_None) { | |||
| cn = ref_var->m_node->comp_node(); | |||
| } else { | |||
| cn = device2obj(device).cast<CompNode>(); | |||
| } | |||
| OperatorNodeConfig config(cn); | |||
| auto hv = npy::np2tensor( | |||
| val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>()); | |||
| auto typeobj = ref.get_type(); | |||
| return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); | |||
| } | |||
| py::object device_obj = device2obj(device, true); | |||
| py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none()); | |||
| return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr); | |||
| @@ -373,7 +328,7 @@ py::tuple _make_shape_tuple(py::handle shape) { | |||
| py::list orig; | |||
| py::list ret(0); | |||
| auto solve_one = [&](py::handle val) { | |||
| if (TensorWrapper::try_cast(val.ptr()) || py::isinstance<PySymbolVar>(val)) { | |||
| if (TensorWrapper::try_cast(val.ptr())) { | |||
| py::object np = getattr(val, "numpy")(); | |||
| PyArrayObject* arr = (PyArrayObject*)np.ptr(); | |||
| PyObject* maybe_list = PyArray_ToList(arr); | |||
| @@ -415,25 +370,53 @@ py::tuple _make_shape_tuple(py::handle shape) { | |||
| return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr())); | |||
| } | |||
| bool is_tensor_or_symbolvar(py::handle arg) { | |||
| return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg); | |||
| bool is_tensor(py::handle arg) { | |||
| return bool(TensorWrapper::try_cast(arg.ptr())); | |||
| } | |||
| bool is_py_sequence(py::handle arg) { | |||
| if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) || | |||
| py::isinstance<PySymbolVar>(arg)) { | |||
| if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr())) { | |||
| return false; | |||
| } | |||
| return PySequence_Check(arg.ptr()); | |||
| } | |||
| mgb::DType _get_dtype(py::handle tensor) { | |||
| if (auto tw = TensorWrapper::try_cast(tensor.ptr())) { | |||
| return tw->m_tensor->dtype(); | |||
| py::object get_res_by_refhdl( | |||
| py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) { | |||
| py::object res = _Const(value, dtype, device); | |||
| py::object ref; | |||
| if (py::isinstance<py::tuple>(ref_hdl)) { | |||
| py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl); | |||
| if (tup.size()) { | |||
| ref = tup[0]; | |||
| } else { | |||
| ref = py::none(); | |||
| } | |||
| } else { | |||
| auto var = tensor.cast<PySymbolVar*>(); | |||
| return var->m_node->dtype(); | |||
| ref = py::reinterpret_borrow<py::object>(ref_hdl); | |||
| } | |||
| if (PyObject_TypeCheck(ref.ptr(), py_varnode_type)) { | |||
| auto temp = dtype.cast<mgb::DType>(); | |||
| ComputingGraph* graph = getattr(ref, "graph").cast<ComputingGraph*>(); | |||
| cg::VarNode* node = getattr(ref, "var").cast<cg::VarNode*>(); | |||
| CompNode cn; | |||
| if (device.ptr() == Py_None) { | |||
| cn = node->comp_node(); | |||
| } else { | |||
| cn = device2obj(device).cast<CompNode>(); | |||
| } | |||
| OperatorNodeConfig config(cn); | |||
| auto hv = npy::np2tensor( | |||
| value.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>()); | |||
| auto typeobj = ref.get_type(); | |||
| return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); | |||
| } | |||
| return res; | |||
| } | |||
| mgb::DType _get_dtype(py::handle tensor) { | |||
| auto tw = TensorWrapper::try_cast(tensor.ptr()); | |||
| return tw->m_tensor->dtype(); | |||
| } | |||
| py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) { | |||
| @@ -457,12 +440,12 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) { | |||
| py::object _convert_single_value_cpp( | |||
| py::handle value, py::handle dtype, py::handle device) { | |||
| if (is_tensor_or_symbolvar(value)) { | |||
| if (is_tensor(value)) { | |||
| if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) { | |||
| return _astype_cpp(value, dtype); | |||
| } | |||
| } else { | |||
| return _Const(value, dtype, device, py::none()); | |||
| return _Const(value, dtype, device); | |||
| } | |||
| return py::reinterpret_borrow<py::object>(value); | |||
| } | |||
| @@ -475,28 +458,8 @@ py::object _convert_inputs_cpp( | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| py::handle h = py::handle(args[i]); | |||
| lis.append(h); | |||
| if (py::isinstance<PySymbolVar>(h)) { | |||
| auto var = h.cast<PySymbolVar*>(); | |||
| auto g = var->m_node->owner_graph(); | |||
| if (!graph) { | |||
| graph = g; | |||
| typeobj = h.get_type(); | |||
| } else { | |||
| mgb_assert(graph == g); | |||
| } | |||
| } | |||
| } | |||
| if (graph) { | |||
| CompNode cn = device2obj(device).cast<CompNode>(); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| OperatorNodeConfig config(cn); | |||
| auto hv = npy::np2tensor( | |||
| lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>()); | |||
| if (!py::isinstance<PySymbolVar>(lis[i])) { | |||
| lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); | |||
| } | |||
| } | |||
| } | |||
| auto convert = [&](py::object value) { | |||
| if (value.is_none()) { | |||
| return value; | |||
| @@ -517,7 +480,8 @@ py::object _astensor1d_cpp( | |||
| if (device.ptr() != Py_None) { | |||
| device_obj = device2obj(device); | |||
| } | |||
| if (py::isinstance<PySymbolVar>(value)) { | |||
| if (PyObject_TypeCheck(value.ptr(), py_varnode_type)) { | |||
| try { | |||
| getattr(value, "ndim"); | |||
| } catch (py::error_already_set& err) { | |||
| @@ -537,14 +501,15 @@ py::object _astensor1d_cpp( | |||
| return ret; | |||
| } | |||
| } | |||
| size_t ndim = 999; | |||
| if (hasattr(value, "ndim")) { | |||
| ndim = getattr(value, "ndim").cast<size_t>(); | |||
| if (ndim != 0 && ndim != 1) { | |||
| throw py::value_error("ndim != 1 or 0, get : " + std::to_string(ndim)); | |||
| } | |||
| if (!is_tensor_or_symbolvar(value)) { | |||
| return _Const(value, dtype, device, ref); | |||
| if (!is_tensor(value)) { | |||
| return get_res_by_refhdl(value, dtype, device, ref); | |||
| } else { | |||
| return py::reinterpret_borrow<py::object>(value); | |||
| } | |||
| @@ -555,13 +520,13 @@ py::object _astensor1d_cpp( | |||
| py::list lis = py::reinterpret_steal<py::list>(PySequence_List(value.ptr())); | |||
| bool need_concat = false; | |||
| for (size_t i = 0; i < lis.size(); ++i) { | |||
| if (is_tensor_or_symbolvar(lis[i])) { | |||
| if (is_tensor(lis[i])) { | |||
| need_concat = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!need_concat) { | |||
| return _Const(value, dtype, device, ref); | |||
| return get_res_by_refhdl(value, dtype, device, ref); | |||
| } | |||
| if (lis.size() > 1) { | |||
| std::vector<PyObject*> c_args(lis.size() + 1); | |||
| @@ -600,10 +565,9 @@ py::object _astensor1d_cpp( | |||
| } | |||
| py::object _get_index(py::object tensor, py::object src) { | |||
| if (!TensorWrapper::try_cast(tensor.ptr()) && | |||
| !py::isinstance<PySymbolVar>(tensor)) { | |||
| if (!TensorWrapper::try_cast(tensor.ptr())) { | |||
| auto get_const = [&](mgb::DType dtype) -> py::object { | |||
| return _Const(tensor, py::cast(dtype), src.attr("device"), src); | |||
| return _Const(tensor, py::cast(dtype), src.attr("device")); | |||
| }; | |||
| if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) { | |||
| tensor = get_const(dtype::Bool()); | |||
| @@ -636,9 +600,8 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) { | |||
| } | |||
| py::object iobj; | |||
| if (PyArray_Check(index.ptr())) { | |||
| iobj = | |||
| _Const(index, py::cast((mgb::DType)dtype::Bool()), | |||
| getattr(tensor, "device"), tensor); | |||
| iobj = _Const( | |||
| index, py::cast((mgb::DType)dtype::Bool()), getattr(tensor, "device")); | |||
| } else { | |||
| iobj = py::reinterpret_borrow<py::object>(index); | |||
| } | |||
| @@ -920,8 +883,8 @@ py::object _expand_args(py::handle args) { | |||
| return py::reinterpret_borrow<py::object>(args); | |||
| } | |||
| py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr()); | |||
| if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) || | |||
| is_tensor_or_symbolvar(args_tup[0].ptr()))) { | |||
| if (args_tup.size() == 1 && | |||
| (PySequence_Check(args_tup[0].ptr()) || is_tensor(args_tup[0].ptr()))) { | |||
| return py::reinterpret_borrow<py::object>(args_tup[0]); | |||
| } else { | |||
| return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr())); | |||
| @@ -948,7 +911,8 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) { | |||
| bool enable_fastpath(py::handle inp) { | |||
| auto&& tm_tr = TransformationManager::get_instance() | |||
| .segments[TransformationManager::Segment::ModuleTrace]; | |||
| if (!TensorWrapper::try_cast(inp.ptr()) || | |||
| bool is_varnode = PyObject_TypeCheck(inp.ptr(), py_varnode_type); | |||
| if (is_varnode || | |||
| TransformationManager::get_instance() | |||
| .segments[TransformationManager::Segment::Trace] | |||
| .size() > 0 || | |||
| @@ -1181,10 +1145,8 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) { | |||
| py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) { | |||
| py::object org_shape = getattr(inp_hdl, "shape"); | |||
| py::object val = py::reinterpret_borrow<py::object>(val_hdl); | |||
| if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) { | |||
| val = | |||
| _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"), | |||
| inp_hdl); | |||
| if (!TensorWrapper::try_cast(val.ptr())) { | |||
| val = _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device")); | |||
| } | |||
| py::tuple up = _unpack_indexes(inp_hdl, idx_hdl); | |||
| @@ -1308,12 +1270,12 @@ py::object _split_cpp( | |||
| repr(nsplits_or_sections_hdl).cast<std::string>()); | |||
| } | |||
| py::object pos = div_points[i] - div_points[i - 1]; | |||
| if (is_tensor_or_symbolvar(pos)) { | |||
| if (is_tensor(pos)) { | |||
| partitions.append(pos); | |||
| } else { | |||
| partitions.append( | |||
| _Const(pos, py::cast((mgb::DType)dtype::Int32()), | |||
| getattr(inp_hdl, "device"), inp_hdl)); | |||
| getattr(inp_hdl, "device"))); | |||
| } | |||
| } | |||
| op = Split::make(axis, 0); | |||
| @@ -1438,7 +1400,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) { | |||
| py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||
| py::object obj = _expand_args(args); | |||
| py::list lis; | |||
| if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) { | |||
| if (!is_tensor(obj.ptr()) && PySequence_Check(obj.ptr())) { | |||
| lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr())); | |||
| } else { | |||
| py::object np = getattr(obj, "numpy")(); | |||
| @@ -1631,7 +1593,7 @@ PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs) | |||
| PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) { | |||
| try { | |||
| return _Const(args[0], args[1], args[2], args[3]).release().ptr(); | |||
| return _Const(args[0], args[1], args[2]).release().ptr(); | |||
| } | |||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||
| } | |||
| @@ -1696,4 +1658,4 @@ PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||
| } | |||
| } // namespace mgb::imperative::python | |||
| } // namespace mgb::imperative::python | |||
| @@ -20,11 +20,12 @@ public: | |||
| DimExpansion, | |||
| Grad, | |||
| Scalar, | |||
| Symbol, | |||
| Trace, | |||
| Eval, | |||
| }; | |||
| std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments; | |||
| std::array<std::vector<std::shared_ptr<Transformation>>, 8> segments; | |||
| private: | |||
| template <Segment segment> | |||
| @@ -11,7 +11,7 @@ from megengine.utils.network_node import VarNode | |||
| def _default_compare_fn(x, y): | |||
| if isinstance(x, tensor): | |||
| if isinstance(x, tensor) and not isinstance(x, VarNode): | |||
| x = x.numpy() | |||
| elif not isinstance(x, np.ndarray): | |||
| x = get_var_value(x) | |||
| @@ -679,6 +679,18 @@ def test_utils_astensor1d(is_varnode): | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(xx.numpy(), [1, 2, 3]) | |||
| # varnode | |||
| if is_varnode: | |||
| a = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32") | |||
| b = np.array([[True, False, True], [False, True, True]]) | |||
| aa = make_tensor(a, network) | |||
| bb = make_tensor(b, network) | |||
| x, y = F.cond_take(bb, aa) | |||
| for dtype in [None, "float32"]: | |||
| xx = astensor1d(x, reference, dtype=dtype) | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(get_var_value(xx), get_var_value(x)) | |||
| def test_device(): | |||
| x = tensor([1, 2, 3], dtype="float32") | |||
| @@ -114,8 +114,10 @@ def test_replace_opr(): | |||
| vara = graph.var_filter.name("a").as_unique() | |||
| varb = graph.var_filter.name("b").as_unique() | |||
| out1 = F.sub(vara, varb) | |||
| out1 = F.mul(vara, varb) | |||
| out1 = F.relu(out1) | |||
| out1 += 2 | |||
| out1 *= 3 | |||
| out1 = graph.add_dep_oprs(out1) | |||
| orig_opr = graph.opr_filter.has_input(vara).as_unique() | |||
| @@ -135,7 +137,7 @@ def test_replace_opr(): | |||
| load_graph = GraphInference(modified_model1) | |||
| out = load_graph.run(a, b) | |||
| np.testing.assert_equal(out["o"], [0, 0]) | |||
| np.testing.assert_equal(out["o"], [30, 60]) | |||
| def test_splice_network(): | |||
| @@ -82,6 +82,10 @@ std::string DTRCommand::to_string() const { | |||
| return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind); | |||
| } | |||
| std::string CreateNode::to_string() const { | |||
| return "CreateNode"; | |||
| } | |||
| std::string GetName::to_string() const { | |||
| return "GetName{}"; | |||
| } | |||
| @@ -94,5 +98,9 @@ std::string IsScalar::to_string() const { | |||
| return "IsScalar"; | |||
| } | |||
| std::string GetVarVal::to_string() const { | |||
| return "GetVarVal"; | |||
| } | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -157,5 +157,22 @@ public: | |||
| std::string to_string() const override; | |||
| }; | |||
| class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> { | |||
| public: | |||
| std::string to_string() const override; | |||
| }; | |||
| class CreateNode final : public OperatorImpl<CreateNode> { | |||
| private: | |||
| cg::VarNode* m_node; | |||
| public: | |||
| CreateNode(cg::VarNode* node) : m_node(node) {} | |||
| cg::VarNode* node() const { return m_node; } | |||
| std::string to_string() const override; | |||
| }; | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -173,5 +173,24 @@ public: | |||
| std::string to_string() const override; | |||
| }; | |||
| class NodeStorage { | |||
| private: | |||
| cg::VarNode* m_node; | |||
| public: | |||
| NodeStorage() = default; | |||
| NodeStorage(VarNode* node) : m_node(node) {} | |||
| VarNode* node() const { return m_node; } | |||
| ComputingGraph* graph() const { return m_node->owner_graph(); } | |||
| std::string to_string() const { return m_node->name(); } | |||
| }; | |||
| class NodeValue final : public PrimitiveValue<NodeValue, NodeStorage> { | |||
| public: | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override { return NodeStorage::to_string(); } | |||
| }; | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -39,13 +39,39 @@ private: | |||
| ObjectType<SymbolValue> m_value_type{"SymbolValue"}; | |||
| public: | |||
| SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} | |||
| SymbolTransformation() {} | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override { | |||
| ComputingGraph* cg = nullptr; | |||
| if (auto* node_value = op.as<CreateNode>()) { | |||
| return {m_value_type.make(node_value->node())}; | |||
| } | |||
| for (auto&& input : inputs) { | |||
| if (auto* val = input.as(m_value_type)) { | |||
| auto* node = val->node(); | |||
| ComputingGraph* cur_cg = node->owner_graph(); | |||
| if (cg == nullptr) { | |||
| cg = cur_cg; | |||
| } else { | |||
| mgb_assert(cg == cur_cg, "input varnode gragh should be the same"); | |||
| } | |||
| } | |||
| } | |||
| if (!cg) { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| if (auto* apply_op = op.as<ApplyOp>()) { | |||
| SmallVector<VarNode*> input_nodes; | |||
| for (auto&& input : inputs) { | |||
| input_nodes.push_back(input.cast(m_value_type).node()); | |||
| if (!input.is(m_value_type)) { | |||
| auto* node = opr::ImmutableTensor::make( | |||
| *cg, input.numpy()->as_nd(true), {}) | |||
| .node(); | |||
| input_nodes.push_back(node); | |||
| } else { | |||
| input_nodes.push_back(input.cast(m_value_type).node()); | |||
| } | |||
| } | |||
| auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); | |||
| ValueRefList outputs(output_nodes.size()); | |||
| @@ -53,15 +79,9 @@ public: | |||
| outputs[i] = m_value_type.make(output_nodes[i]); | |||
| } | |||
| return outputs; | |||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
| auto&& args = create_tensor->parse(inputs); | |||
| mgb_assert( | |||
| args.kind == CreateTensor::Const, | |||
| "only const value is allowed here"); | |||
| auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node(); | |||
| return {m_value_type.make(node)}; | |||
| } else if (auto* get_attr = op.as<GetAttr>()) { | |||
| auto* node = inputs.item().cast(m_value_type).node(); | |||
| auto* m_graph = node->owner_graph(); | |||
| switch (get_attr->attr()) { | |||
| case GetAttr::DType: | |||
| return {DTypeValue::make(node->dtype())}; | |||
| @@ -105,6 +125,10 @@ public: | |||
| MegBrainError, "Symbol: malformed GetAttr: %s", | |||
| op.to_string().c_str()); | |||
| } | |||
| } else if (auto* get_attr = op.as<GetVarVal>()) { | |||
| cg::VarNode* node = inputs.item().cast(m_value_type).node(); | |||
| NodeStorage inp_var = NodeStorage(node); | |||
| return {NodeValue::make(inp_var)}; | |||
| } else { | |||
| return op.fallback(inputs); | |||
| } | |||
| @@ -33,6 +33,7 @@ class ShapeValue; | |||
| class DTypeValue; | |||
| class CompNodeValue; | |||
| class StringValue; | |||
| class NodeValue; | |||
| class Operator; | |||