| @@ -8,6 +8,9 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from .._imperative_rt import make_const | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor | |||
| class Const: | |||
| def __init__(self, value=None, *, dtype=None, device=None): | |||
| @@ -19,7 +22,19 @@ class Const: | |||
| from ...tensor import Tensor | |||
| device = self.device | |||
| if device is None: | |||
| device = reference[0].device | |||
| if len(reference) != 0: | |||
| reference = reference[0] | |||
| assert isinstance( | |||
| reference, (SymbolVar, Tensor) | |||
| ), "Reference should be Tensor or VarNode" | |||
| if device is None: | |||
| device = reference.device | |||
| if isinstance(reference, SymbolVar): | |||
| cls = type(reference) | |||
| rst = cls(make_const(reference.graph, self.value, device, self.dtype)) | |||
| return (rst,) | |||
| return (Tensor(self.value, self.dtype, self.device, True),) | |||
| @@ -13,7 +13,7 @@ from typing import Union | |||
| import numpy as np | |||
| from .._imperative_rt.common import CompNode | |||
| from .._imperative_rt.core2 import Tensor, apply | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply | |||
| from ..ops import builtin | |||
| from ..ops.builtin import Elemwise, GetVarShape | |||
| from . import utils | |||
| @@ -230,7 +230,9 @@ def _todo(*_): | |||
| def _expand_args(args): | |||
| if len(args) == 1: | |||
| if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),): | |||
| if isinstance( | |||
| args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray), | |||
| ): | |||
| args = args[0] | |||
| return args | |||
| @@ -10,7 +10,7 @@ from typing import Iterable | |||
| import numpy as np | |||
| from .._imperative_rt.core2 import Tensor, apply | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply | |||
| from .._trace_option import use_symbolic_shape | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| @@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
| return True | |||
| def get_index(i): | |||
| if not isinstance(i, (Tensor)): | |||
| if not isinstance(i, (Tensor, SymbolVar)): | |||
| if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | |||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)() | |||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||
| else: | |||
| (i,) = Const(i, dtype=np.int32, device=inp.device)() | |||
| (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||
| return i | |||
| assert isinstance(i, Tensor) | |||
| assert isinstance(i, (Tensor, SymbolVar)) | |||
| if i.dtype != np.bool_: | |||
| return i | |||
| _, ind = apply(builtin.CondTake(), i, i) | |||
| @@ -197,9 +197,9 @@ def try_condtake(tensor, index): | |||
| ): | |||
| return [] | |||
| if isinstance(index, np.ndarray): | |||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)() | |||
| assert isinstance(index, Tensor) | |||
| if not isinstance(tensor, Tensor): | |||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | |||
| assert isinstance(index, (Tensor, SymbolVar)) | |||
| if not isinstance(tensor, (Tensor, SymbolVar)): | |||
| raise TypeError("input must be a tensor") | |||
| if tensor.device != index.device: | |||
| raise ValueError( | |||
| @@ -214,11 +214,16 @@ def getitem(tensor, index): | |||
| return try_result[0] | |||
| tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) | |||
| for v in tensors: | |||
| if v.shape is None: | |||
| break | |||
| if isinstance(v.shape, v.__class__): | |||
| break | |||
| if len(v.shape) > 0 and v.shape[0] == 0: | |||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() | |||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||
| tensor | |||
| ) | |||
| return empty_tensor | |||
| if use_subtensor: | |||
| op = builtin.Subtensor(items=items) | |||
| else: | |||
| @@ -235,8 +240,8 @@ def setitem(tensor, index, value): | |||
| if len(try_result) == 2: | |||
| index = try_result[1] | |||
| tensor = tensor.reshape(-1) | |||
| if not isinstance(value, Tensor): | |||
| (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() | |||
| if not isinstance(value, (Tensor, SymbolVar)): | |||
| (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) | |||
| tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | |||
| if use_subtensor: | |||
| op = builtin.Subtensor(items=items) | |||
| @@ -11,8 +11,9 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from .._imperative_rt import VarNode | |||
| from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||
| from .._imperative_rt import make_const | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device | |||
| from .._wrap import device as as_device | |||
| from ..ops import builtin | |||
| from ..ops.special import Const | |||
| from .dtype import is_dtype_equal, is_quantize | |||
| @@ -38,13 +39,9 @@ def set_convert_inputs(flag): | |||
| def concatenate(inputs, axis=0, *, device=None): | |||
| dtype = dtype_promotion(inputs) | |||
| device = get_device(inputs) | |||
| def convert(x): | |||
| return convert_single_value(x, dtype=dtype, device=device) | |||
| inputs = tuple(map(convert, inputs)) | |||
| inputs = convert_inputs(*inputs) | |||
| if device is None: | |||
| device = get_device(inputs) | |||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | |||
| return result | |||
| @@ -60,7 +57,7 @@ def astype(x, dtype): | |||
| def convert_single_value(v, *, dtype=None, device=None): | |||
| if isinstance(v, (Tensor, VarNode)): | |||
| if isinstance(v, (Tensor, SymbolVar)): | |||
| if not is_quantize(v.dtype): | |||
| v = astype(v, dtype) | |||
| else: | |||
| @@ -68,17 +65,35 @@ def convert_single_value(v, *, dtype=None, device=None): | |||
| return v | |||
| def convert_inputs(*args: Tensor): | |||
| def convert_inputs(*args, device=None): | |||
| if not _enable_convert_inputs: | |||
| return args | |||
| dtype = dtype_promotion(args) | |||
| device = get_device(args) | |||
| if device is None: | |||
| device = get_device(args) | |||
| device = as_device(device) | |||
| graph = None | |||
| sym_type = None | |||
| for a in args: | |||
| if isinstance(a, SymbolVar): | |||
| if graph is None: | |||
| graph = a.var.graph | |||
| sym_type = type(a) | |||
| else: | |||
| assert graph == a.var.graph | |||
| args = list(args) | |||
| if graph is not None: | |||
| for i in range(len(args)): | |||
| if not isinstance(args[i], SymbolVar): | |||
| rst = make_const(graph, np.array(args[i]), device.to_c(), dtype) | |||
| args[i] = sym_type(rst) | |||
| def convert(value): | |||
| if value is None: | |||
| return value | |||
| return convert_single_value(value, dtype=dtype, device=device) | |||
| return convert_single_value(value, dtype=dtype, device=device.to_c()) | |||
| return tuple(map(convert, args)) | |||
| @@ -98,14 +113,14 @@ def result_type(*args): | |||
| def isscalar(x): | |||
| if isinstance(x, Tensor): | |||
| if isinstance(x, (Tensor, SymbolVar)): | |||
| return x._isscalar() | |||
| return np.isscalar(x) | |||
| def setscalar(x): | |||
| if isinstance(x, Tensor): | |||
| if isinstance(x, (Tensor, SymbolVar)): | |||
| x._setscalar() | |||
| else: | |||
| raise NotImplementedError("Unsupport type {}".format(type(x))) | |||
| @@ -132,7 +147,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
| if not isinstance(x, collections.abc.Sequence): | |||
| raise TypeError | |||
| if any(isinstance(i, Tensor) for i in x): | |||
| if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | |||
| x = concatenate(x, device=device) | |||
| if dtype is not None: | |||
| x = astype(x, dtype) | |||
| @@ -142,7 +157,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
| def _expand_int(s, i): | |||
| if isinstance(i, Tensor): | |||
| if isinstance(i, (Tensor, SymbolVar)): | |||
| i_np = i.numpy() | |||
| if i_np.ndim == 0: | |||
| s.append(int(i_np)) | |||
| @@ -9,8 +9,7 @@ | |||
| # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core._imperative_rt.graph import VarNode | |||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Elemwise | |||
| from ..core.tensor import utils | |||
| @@ -72,7 +71,7 @@ __all__ = [ | |||
| def _elwise(*args, mode): | |||
| tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) | |||
| tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args)) | |||
| if len(tensor_args) == 0: | |||
| dtype = utils.dtype_promotion(args) | |||
| first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) | |||
| @@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Union | |||
| import numpy as np | |||
| from ..core._imperative_rt import CompNode | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||
| from ..core._wrap import device as as_device | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Copy, Identity | |||
| @@ -101,7 +101,7 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Ten | |||
| return result | |||
| def full(shape, value, dtype="float32", device=None): | |||
| def full(shape, value, dtype="float32", device=None) -> Tensor: | |||
| """ | |||
| Returns a tensor with given shape and value. | |||
| """ | |||
| @@ -115,7 +115,7 @@ def full(shape, value, dtype="float32", device=None): | |||
| return broadcast_to(x, shape) | |||
| def ones(shape, dtype="float32", device=None): | |||
| def ones(shape, dtype="float32", device=None) -> Tensor: | |||
| """ | |||
| Returns a ones tensor with given shape. | |||
| @@ -142,14 +142,14 @@ def ones(shape, dtype="float32", device=None): | |||
| return full(shape, 1.0, dtype=dtype, device=device) | |||
| def zeros(shape, dtype="float32", device=None): | |||
| def zeros(shape, dtype="float32", device=None) -> Tensor: | |||
| """ | |||
| Returns a zero tensor with given shape. | |||
| """ | |||
| return full(shape, 0.0, dtype=dtype, device=device) | |||
| def zeros_like(inp: Tensor) -> Tensor: | |||
| def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
| """ | |||
| Returns a zero tensor with the same shape as input tensor. | |||
| @@ -176,21 +176,26 @@ def zeros_like(inp: Tensor) -> Tensor: | |||
| [0 0 0]] | |||
| """ | |||
| return zeros(inp.shape, dtype=inp.dtype, device=inp.device) | |||
| return full_like(inp, 0.0) | |||
| def ones_like(inp: Tensor) -> Tensor: | |||
| def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
| """ | |||
| Returns a ones tensor with the same shape as input tensor. | |||
| """ | |||
| return ones(inp.shape, dtype=inp.dtype, device=inp.device) | |||
| return full_like(inp, 1.0) | |||
| def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||
| def full_like( | |||
| inp: Union[Tensor, SymbolVar], value: Union[int, float] | |||
| ) -> Union[Tensor, SymbolVar]: | |||
| """ | |||
| Returns a tensor filled with given value with the same shape as input tensor. | |||
| """ | |||
| return full(inp.shape, value, dtype=inp.dtype, device=inp.device) | |||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||
| if inp.shape is (): | |||
| return x | |||
| return broadcast_to(x, inp.shape) | |||
| def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | |||
| @@ -259,15 +264,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||
| if len(inps) == 1: | |||
| return inps[0] | |||
| dtype = dtype_promotion(inps) | |||
| inps = convert_inputs(*inps, device=device) | |||
| if device is None: | |||
| device = get_device(inps) | |||
| device = as_device(device) | |||
| def convert(x): | |||
| return convert_single_value(x, dtype=dtype, device=device) | |||
| inps = tuple(map(convert, inps)) | |||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | |||
| return result | |||
| @@ -379,8 +379,14 @@ def split(inp, nsplits_or_sections, axis=0): | |||
| Ntotal, axis, Nsections | |||
| ) | |||
| ) | |||
| func = ( | |||
| floor_div | |||
| if isinstance(Nsections, (SymbolVar, Tensor)) | |||
| else lambda x, y: x // y | |||
| ) | |||
| div_points = [0] + [ | |||
| floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||
| func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||
| ] | |||
| for i in range(2, Nsections + 1): | |||
| div_points[i] = div_points[i - 1] + div_points[i] | |||
| @@ -925,11 +931,15 @@ def linspace( | |||
| if not (cur_device is None or device == cur_device): | |||
| raise ("ambiguous device for linspace opr") | |||
| if not isinstance(start, Tensor): | |||
| is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num]) | |||
| if any(is_symbolvar) and not all(is_symbolvar): | |||
| raise TypeError("start, stop and num should all be VarNode or none of them") | |||
| if not isinstance(start, (Tensor, SymbolVar)): | |||
| start = Tensor(start, device=device) | |||
| if not isinstance(stop, Tensor): | |||
| if not isinstance(stop, (Tensor, SymbolVar)): | |||
| stop = Tensor(stop, device=device) | |||
| if not isinstance(num, Tensor): | |||
| if not isinstance(num, (Tensor, SymbolVar)): | |||
| num = Tensor(num, device=device) | |||
| op = builtin.Linspace(comp_node=device) | |||
| @@ -983,7 +993,7 @@ def arange( | |||
| stop = stop.astype("float32") | |||
| if isinstance(step, Tensor): | |||
| step = step.astype("float32") | |||
| num = ceil(Tensor((stop - start) / step, device=device)) | |||
| num = ceil((stop - start) / step) | |||
| stop = start + step * (num - 1) | |||
| result = linspace(start, stop, num, device=device) | |||
| if np.dtype(dtype) == np.int32: | |||
| @@ -16,6 +16,7 @@ from typing import Dict, List | |||
| import numpy as np | |||
| from ..core._imperative_rt import ComputingGraph | |||
| from ..core._imperative_rt.core2 import SymbolVar | |||
| from ..core.tensor import megbrain_graph as G | |||
| from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | |||
| from .network_node import ( | |||
| @@ -60,12 +61,12 @@ class Network: | |||
| ) | |||
| outputs = [new_outputs[i] for i in outspec] | |||
| self._orig_outputs = outputs | |||
| self.add_dep_oprs(*outputs) | |||
| for x in self._orig_outputs: | |||
| self.output_vars.append(self._get_var(x)) | |||
| self.add_dep_oprs() | |||
| for x in self._orig_inputs: | |||
| self.input_vars.append(self._get_var(x)) | |||
| for x in self._orig_outputs: | |||
| self.output_vars.append(self._get_var(x)) | |||
| self.graph = self._orig_outputs[0].graph | |||
| return self | |||
| @@ -197,6 +198,8 @@ class Network: | |||
| def add_output(self, *vars: VarNode): | |||
| """Adds vars into the network output node list | |||
| """ | |||
| if not all([var.owner for var in vars]): | |||
| self.add_dep_oprs(*vars) | |||
| for var in vars: | |||
| if var not in self.output_vars: | |||
| self.output_vars.append(var) | |||
| @@ -209,21 +212,25 @@ class Network: | |||
| self.output_vars.remove(var) | |||
| def add_dep_oprs(self, *vars): | |||
| """Adds dependent opnodes and varnodes of vars into network | |||
| """ | |||
| oprs = get_oprs_seq(vars, False, False) | |||
| for mge_opr in oprs: | |||
| if len(vars) == 0: | |||
| vars = self.output_vars | |||
| q = list(vars) | |||
| while len(q) > 0: | |||
| cur = q.pop(0) | |||
| if cur.owner is not None: | |||
| continue | |||
| if cur.name is None: | |||
| cur.name = cur.var.name | |||
| self.all_vars_map[cur.var.id] = cur | |||
| mge_opr = cur.var.owner | |||
| if get_opr_type(mge_opr) == "Host2DeviceCopy": | |||
| self._orig_inputs.extend(mge_opr.outputs) | |||
| opr = self._add_opr(mge_opr) | |||
| if opr is not None: | |||
| for x in mge_opr.inputs: | |||
| opr.add_inp_var(self._get_var(x)) | |||
| # set out var | |||
| for x in mge_opr.outputs: | |||
| opr.add_out_var(self._get_var(x)) | |||
| return [self.all_vars_map[var.id] for var in vars] | |||
| cur.owner = self._add_opr(mge_opr) | |||
| if cur.owner is None: | |||
| cur.owner = self.all_oprs_map[mge_opr.id] | |||
| continue | |||
| q.extend(cur.owner.inputs) | |||
| return list(vars) | |||
| def modify_opr_names(self, modifier): | |||
| """Modifies names of operators **inplace**; useful for merging loaded | |||
| @@ -275,6 +282,9 @@ class Network: | |||
| Replaces vars in the graph. | |||
| :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | |||
| """ | |||
| if not all([var.owner for var in repl_dict.values()]): | |||
| print(repl_dict.values()) | |||
| self.add_dep_oprs(*list(repl_dict.values())) | |||
| for var in self.all_vars: | |||
| if var in repl_dict: | |||
| repl_var = repl_dict[var] | |||
| @@ -282,6 +292,7 @@ class Network: | |||
| idx = owner.outputs.index(repl_var) | |||
| owner.outputs[idx] = var | |||
| var.__dict__.update(repl_var.__dict__) | |||
| var.var = repl_var.var | |||
| def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | |||
| """ | |||
| @@ -297,6 +308,7 @@ class Network: | |||
| for ind, var in enumerate(opr.outputs): | |||
| var.owner = repl_dict[opr] | |||
| var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | |||
| var.var = repl_dict[opr].outputs[ind].var | |||
| def get_opr_by_type(self, oprcls, unique=True): | |||
| assert issubclass(oprcls, OpNode) | |||
| @@ -381,11 +393,16 @@ class Network: | |||
| return self.opr_filter.as_dict() | |||
| # used for loading and building graph | |||
| def _add_opr(self, x): | |||
| def _add_opr(self, opr): | |||
| # TODO: use megbrain C++ RTTI to replace type string | |||
| if x.id not in self.all_oprs_map: | |||
| self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x) | |||
| return self.all_oprs_map[x.id] | |||
| if opr.id not in self.all_oprs_map: | |||
| opnode = str_to_mge_class(get_opr_type(opr)).load(opr) | |||
| self.all_oprs_map[opr.id] = opnode | |||
| for var in opr.inputs: | |||
| opnode.add_inp_var(self._get_var(var)) | |||
| for var in opr.outputs: | |||
| opnode.add_out_var(self._get_var(var)) | |||
| return opnode | |||
| else: | |||
| return None | |||
| @@ -397,7 +414,7 @@ class Network: | |||
| def _get_var(self, x): | |||
| # auto convert to VarNode of Network | |||
| if x.id not in self.all_vars_map: | |||
| if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x: | |||
| self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | |||
| return self.all_vars_map[x.id] | |||
| @@ -652,7 +669,7 @@ class NodeFilterHasInput(NodeFilter): | |||
| assert isinstance( | |||
| i, OpNode | |||
| ), "has_input() must be used with OpNode; " "got {!r}".format(i) | |||
| if self.var in i.inputs: | |||
| if any(self.var is _ for _ in i.inputs): | |||
| yield i | |||
| @@ -6,16 +6,21 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import abc | |||
| import json | |||
| import sys | |||
| from typing import Callable | |||
| from typing import Callable, Sequence | |||
| import numpy as np | |||
| from ..core import _imperative_rt as rt | |||
| from ..core._imperative_rt.core2 import SymbolVar | |||
| from ..core._wrap import Device | |||
| from ..core.ops import builtin | |||
| from ..core.tensor.megbrain_graph import InputNode | |||
| from ..core.tensor.array_method import ArrayMethodMixin | |||
| from ..core.tensor.indexing import getitem as _getitem | |||
| from ..core.tensor.indexing import setitem as _setitem | |||
| from ..core.tensor.megbrain_graph import InputNode, OutputNode | |||
| from ..tensor import Tensor | |||
| from .comp_graph_tools import replace_vars | |||
| from .module_stats import ( | |||
| @@ -29,9 +34,13 @@ class NetworkNode: | |||
| pass | |||
| class VarNode(NetworkNode): | |||
| def __init__(self, owner_opr=None, name=None): | |||
| self.var = None | |||
| class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)): | |||
| pass | |||
| class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
| def __init__(self, var=None, *, owner_opr=None, name=None): | |||
| SymbolVar.__init__(self, var) | |||
| self.owner = owner_opr | |||
| self.name = name | |||
| self.id = id(self) | |||
| @@ -58,6 +67,40 @@ class VarNode(NetworkNode): | |||
| def dtype(self): | |||
| return self.var.dtype if self.var else None | |||
| def __bool__(self): | |||
| return False | |||
| __index__ = None | |||
| __int__ = None | |||
| __float__ = None | |||
| __complex__ = None | |||
| def __hash__(self): | |||
| return id(self) | |||
| @property | |||
| def _tuple_shape(self): | |||
| return self.var.shape | |||
| def numpy(self): | |||
| o = OutputNode(self.var) | |||
| self.graph.compile(o.outputs).execute() | |||
| return o.get_value().numpy() | |||
| def __getitem__(self, index): | |||
| return _getitem(self, index) | |||
| def __setitem__(self, index, value): | |||
| if index is not Ellipsis: | |||
| value = _setitem(self, index, value) | |||
| if self.owner is not None: | |||
| idx = self.owner.outputs.index(self) | |||
| self.owner.outputs[idx] = VarNode( | |||
| self.var, owner_opr=self.owner, name=self.var.name | |||
| ) | |||
| self.var = value.var | |||
| self.owner = None | |||
| def set_owner_opr(self, owner_opr): | |||
| self.owner = owner_opr | |||
| @@ -135,7 +178,7 @@ class Host2DeviceCopy(OpNode): | |||
| outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
| self._opr = outputs.owner | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(self, self.name)) | |||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
| self.outputs[0].var = outputs | |||
| assert self.outputs[0].owner is self | |||
| @@ -173,8 +216,8 @@ class ImmutableTensor(OpNode): | |||
| def set_value(self, data, device=None): | |||
| assert self.graph is not None | |||
| cn = device if device else self.device | |||
| assert isinstance(data, (int, float, np.ndarray)) | |||
| if isinstance(data, (int, float)): | |||
| assert isinstance(data, (int, float, Sequence, np.ndarray)) | |||
| if not isinstance(data, np.ndarray): | |||
| data = np.array(data) | |||
| if data.dtype == np.float64: | |||
| data = data.astype(np.float32) | |||
| @@ -182,7 +225,7 @@ class ImmutableTensor(OpNode): | |||
| data = data.astype(np.int32) | |||
| varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(self, self.name)) | |||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
| self.outputs[0].var = varnode | |||
| self._opr = varnode.owner | |||
| @@ -160,16 +160,21 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
| if (ctx.op->same_type<BackwardGraph>()) { | |||
| ctx.backward = true; | |||
| } | |||
| if (py::isinstance<cg::VarNode>(py::handle(args[0]))){ | |||
| SmallVector<cg::VarNode*> vinputs(nargs); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>(); | |||
| } | |||
| auto op = ctx.op.get(); | |||
| return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr(); | |||
| } | |||
| if (py::isinstance<PySymbolVar>(py::handle(args[0]))){ | |||
| SmallVector<cg::VarNode*> vinputs(nargs); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node; | |||
| } | |||
| auto op = ctx.op.get(); | |||
| auto rst = OpDef::apply_on_var_node(*op, vinputs); | |||
| auto ret = pybind11::tuple(rst.size()); | |||
| auto typeobj = py::handle(args[0]).get_type(); | |||
| for (size_t i = 0; i<rst.size(); ++i) { | |||
| ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic)); | |||
| } | |||
| return ret.release().ptr(); | |||
| } | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | |||
| @@ -686,9 +691,9 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||
| continue; | |||
| } | |||
| if (py::isinstance<cg::VarNode>(py::handle(handle))){ | |||
| auto var = py::handle(handle).cast<cg::VarNode *>(); | |||
| mgb::DType type = var->dtype(); | |||
| if (py::isinstance<PySymbolVar>(py::handle(handle))){ | |||
| auto var = py::handle(handle).cast<PySymbolVar*>(); | |||
| mgb::DType type = var->m_node->dtype(); | |||
| auto && descr = npy::dtype_mgb2np_descr(type); | |||
| Py_INCREF(descr.get()); | |||
| tensors.emplace_back(descr.get()); | |||
| @@ -737,19 +742,26 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { | |||
| bool valid = false; | |||
| CompNode cn; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; | |||
| TensorWrapper* tw = TensorWrapper::try_cast(handle); | |||
| bool is_var = py::isinstance<cg::VarNode>(py::handle(handle)); | |||
| if (tw || is_var) { | |||
| bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle)); | |||
| if (tw || is_symvar) { | |||
| if (!valid) { | |||
| cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
| cn = tw ? tw->m_tensor->comp_node() | |||
| : py::handle(handle) | |||
| .cast<PySymbolVar*>() | |||
| ->m_node->comp_node(); | |||
| valid = true; | |||
| } else { | |||
| CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
| CompNode cn1 = tw ? tw->m_tensor->comp_node() | |||
| : py::handle(handle) | |||
| .cast<PySymbolVar*>() | |||
| ->m_node->comp_node(); | |||
| if (cn1 != cn) { | |||
| throw py::value_error(ssprintf("ambiguous device: %s vs %s", | |||
| cn.to_string().c_str(), cn1.to_string().c_str())); | |||
| cn.to_string().c_str(), | |||
| cn1.to_string().c_str())); | |||
| } | |||
| } | |||
| } | |||
| @@ -849,6 +861,32 @@ void init_tensor(py::module m) { | |||
| .def("__call__", &TensorWeakRef::operator()) | |||
| .def("_use_cnt", &TensorWeakRef::_use_cnt); | |||
| py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar") | |||
| .def_property_readonly( | |||
| "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); }) | |||
| .def_property("var", [](PySymbolVar* v) { return v->m_node; }, | |||
| [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; }) | |||
| .def_property_readonly( | |||
| "device", | |||
| [](PySymbolVar* v) { return v->m_node->comp_node(); }) | |||
| .def_property_readonly( | |||
| "graph", | |||
| [](PySymbolVar* v) { return v->m_node->owner_graph(); }) | |||
| .def_property_readonly( | |||
| "shape", | |||
| [](PySymbolVar* v) -> const TensorShape* { | |||
| auto&& mgr = v->m_node->owner_graph() | |||
| ->static_infer_manager(); | |||
| return mgr.infer_shape_fallible(v->m_node); | |||
| }) | |||
| .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | |||
| .def("_setscalar", | |||
| [](PySymbolVar* v) { return v->is_scalar = true; }) | |||
| .def(py::init([](cg::VarNode* node) { | |||
| return std::make_shared<PySymbolVar>(node); | |||
| }), | |||
| py::arg() = nullptr); | |||
| static PyMethodDef method_defs[] = { | |||
| MGE_PY_INTERFACE(apply, py_apply), | |||
| MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), | |||
| @@ -181,6 +181,12 @@ struct TensorWrapper { | |||
| PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | |||
| }; | |||
| struct PySymbolVar { | |||
| cg::VarNode* m_node = nullptr; | |||
| bool is_scalar = false; | |||
| PySymbolVar() = default; | |||
| PySymbolVar(VarNode *m): m_node(m){} | |||
| }; | |||
| PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | |||
| @@ -2,9 +2,11 @@ import io | |||
| import numpy as np | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine import tensor | |||
| from megengine.jit import trace | |||
| from megengine.utils.network_node import VarNode | |||
| def _default_compare_fn(x, y): | |||
| @@ -14,8 +16,23 @@ def _default_compare_fn(x, y): | |||
| np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||
| def make_tensor(x, network=None, device=None): | |||
| if network is not None: | |||
| if isinstance(x, VarNode): | |||
| return VarNode(x.var) | |||
| return network.make_const(x, device=device) | |||
| else: | |||
| return tensor(x, device=device) | |||
| def opr_test( | |||
| cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs | |||
| cases, | |||
| func, | |||
| compare_fn=_default_compare_fn, | |||
| ref_fn=None, | |||
| test_trace=True, | |||
| network=None, | |||
| **kwargs | |||
| ): | |||
| """ | |||
| :param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | |||
| @@ -44,7 +61,7 @@ def opr_test( | |||
| if not isinstance(results, (tuple, list)): | |||
| results = (results,) | |||
| for r, e in zip(results, expected): | |||
| if not isinstance(r, tensor): | |||
| if not isinstance(r, (tensor, VarNode)): | |||
| r = tensor(r) | |||
| compare_fn(r, e) | |||
| @@ -72,9 +89,9 @@ def opr_test( | |||
| raise ValueError("the input func should be callable") | |||
| inp, outp = get_param(cases, 0) | |||
| inp_tensor = [tensor(inpi) for inpi in inp] | |||
| inp_tensor = [make_tensor(inpi, network) for inpi in inp] | |||
| if test_trace: | |||
| if test_trace and not network: | |||
| copied_inp = inp_tensor.copy() | |||
| for symbolic in [False, True]: | |||
| traced_func = trace(symbolic=symbolic)(func) | |||
| @@ -10,12 +10,17 @@ import collections | |||
| import numpy as np | |||
| import pytest | |||
| from utils import make_tensor | |||
| import megengine | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| import megengine.functional as F | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.ops import builtin | |||
| from megengine.tensor import Tensor | |||
| from megengine.utils.network import Network | |||
| from megengine.utils.network_node import VarNode | |||
| def cvt_to_shape_desc(val, inpvar, config=None): | |||
| @@ -387,108 +392,130 @@ def test_batched_mesh_indexing(): | |||
| # high level | |||
| def get_value(x): | |||
| if isinstance(x, VarNode): | |||
| var = x.var | |||
| o = G.OutputNode(var) | |||
| graph = x.graph | |||
| graph.compile(o.outputs).execute() | |||
| return o.get_value().numpy() | |||
| else: | |||
| return x.numpy() | |||
| @pytest.mark.parametrize("test_varnode", [True, False]) | |||
| def test_advance_indexing_high_level(test_varnode): | |||
| if test_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def test_advance_indexing_high_level(): | |||
| x = np.arange(25).reshape(5, 5).astype("int32") | |||
| d = np.arange(15).reshape(3, 5).astype("int32") | |||
| xx = Tensor(x) | |||
| xx = make_tensor(x, network) | |||
| np.testing.assert_equal(x[1, :], xx[1, :].numpy()) | |||
| np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||
| np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy()) | |||
| np.testing.assert_equal(x[1, :], get_value(xx[1, :])) | |||
| np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||
| np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :])) | |||
| np.testing.assert_equal(x[:, :], xx[:, :].numpy()) | |||
| np.testing.assert_equal(x[1, 1], xx[1, 1].numpy()) | |||
| np.testing.assert_equal(x[:, :], get_value(xx[:, :])) | |||
| np.testing.assert_equal(x[1, 1], get_value(xx[1, 1])) | |||
| yy = xx[(0, 4, 2), :] | |||
| np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy()) | |||
| np.testing.assert_equal(x[(0, 4, 2), :], get_value(yy)) | |||
| x_ = x.copy() | |||
| x_[(0, 4, 2), :] = d | |||
| xx_ = Tensor(xx) | |||
| xx_ = make_tensor(xx, network) | |||
| xx_[(0, 4, 2), :] = d | |||
| np.testing.assert_equal(x_, xx_.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx_)) | |||
| x = np.arange(27).reshape(3, 3, 3).astype("int32") | |||
| xx = Tensor(x) | |||
| xx = make_tensor(x, network) | |||
| np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy()) | |||
| np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy()) | |||
| np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy()) | |||
| np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy()) | |||
| np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy()) | |||
| np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||
| np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy()) | |||
| np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :])) | |||
| np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1])) | |||
| np.testing.assert_equal(x[1, 0:1, :], get_value(xx[1, 0:1, :])) | |||
| np.testing.assert_equal(x[0:1, 1, 1], get_value(xx[0:1, 1, 1])) | |||
| np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1])) | |||
| np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||
| np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2])) | |||
| x_ = x.copy() | |||
| x_[1, 1, 1] = -1 | |||
| xx[1, 1, 1] = -1 | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x_[:, 1, 1] = -2 | |||
| xx[:, 1, 1] = x_[:, 1, 1] | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x_[0:1, :, 1] = -3 | |||
| xx[0:1, :, 1] = x_[0:1, :, 1] | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x_[0:1, :, 1] = -4 | |||
| y = Tensor(x_) | |||
| y = make_tensor(x_, network) | |||
| xx[0:1, :, 1] = y[0:1, :, 1] | |||
| np.testing.assert_equal(y.numpy(), xx.numpy()) | |||
| np.testing.assert_equal(get_value(y), get_value(xx)) | |||
| x[:] = 1 | |||
| xx[:] = 1 | |||
| np.testing.assert_equal(x, xx.numpy()) | |||
| np.testing.assert_equal(x, get_value(xx)) | |||
| x = np.arange(9).reshape(3, 3).astype("int32") | |||
| xx = Tensor(x) | |||
| xx = make_tensor(x, network) | |||
| y = np.array([1, 2]) | |||
| yy = Tensor(y) | |||
| np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) | |||
| np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) | |||
| np.testing.assert_equal(x[:, y], xx[:, y].numpy()) | |||
| np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) | |||
| yy = make_tensor(y, network) | |||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) | |||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) | |||
| np.testing.assert_equal(x[:, y], get_value(xx[:, y])) | |||
| np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) | |||
| x_ = x.copy() | |||
| x_[:, y[0]] = -1 | |||
| xx_ = Tensor(x_) | |||
| xx_ = make_tensor(x_, network) | |||
| xx[:, yy[0]] = xx_[:, yy[0]] | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x_[:, y] = -1 | |||
| xx_ = Tensor(x_) | |||
| xx_ = make_tensor(x_, network) | |||
| xx[:, yy] = xx_[:, yy] | |||
| np.testing.assert_equal(x_, xx.numpy()) | |||
| np.testing.assert_equal(x_, get_value(xx)) | |||
| x = np.arange(9).reshape(3, 3).astype("int32") | |||
| xx = Tensor(x) | |||
| xx = make_tensor(x, network) | |||
| y = np.array([1]) | |||
| yy = Tensor(y) | |||
| np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) | |||
| np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) | |||
| np.testing.assert_equal(x[:, y], xx[:, y].numpy()) | |||
| yy = make_tensor(y, network) | |||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) | |||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) | |||
| np.testing.assert_equal(x[:, y], get_value(xx[:, y])) | |||
| np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) | |||
| np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) | |||
| x = np.arange(9).reshape(3, 3).astype("int32") | |||
| xx = Tensor(x) | |||
| np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy()) | |||
| np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy()) | |||
| def test_advance_indexing_with_bool(): | |||
| xx = make_tensor(x, network) | |||
| np.testing.assert_equal(x[[0, 1], 0], get_value(xx[[0, 1], 0])) | |||
| np.testing.assert_equal(x[0:2, 0], get_value(xx[0:2, 0])) | |||
| @pytest.mark.parametrize( | |||
| "test_varnode", [True, False], | |||
| ) | |||
| def test_advance_indexing_with_bool(test_varnode): | |||
| if test_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| a = np.arange(9).reshape(3, 3).astype(np.float32) | |||
| b = np.array([1, 2, 3]) | |||
| c = np.array([1, 2, 3]) | |||
| aa = Tensor(a) | |||
| bb = Tensor(b) | |||
| cc = Tensor(c) | |||
| np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy()) | |||
| aa = make_tensor(a, network) | |||
| bb = make_tensor(b, network) | |||
| cc = make_tensor(c, network) | |||
| np.testing.assert_equal(a[b == 1, c == 2], get_value(aa[bb == 1, cc == 2])) | |||
| a[b == 1, c == 2] = -1.0 | |||
| aa[bb == 1, cc == 2] = -1.0 | |||
| np.testing.assert_equal(a, aa.numpy()) | |||
| np.testing.assert_equal(a, get_value(aa)) | |||
| a = np.arange(9).reshape(3, 3).astype(np.float32) | |||
| b = np.array([False, True, True]) | |||
| @@ -11,13 +11,16 @@ import platform | |||
| import numpy as np | |||
| import pytest | |||
| from utils import opr_test | |||
| from utils import make_tensor, opr_test | |||
| import megengine.functional as F | |||
| from megengine import tensor | |||
| from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.tensor import megbrain_graph as G | |||
| from megengine.core.tensor.utils import astensor1d | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| from megengine.utils.network import Network | |||
| from megengine.utils.network_node import VarNode | |||
| def test_eye(): | |||
| @@ -38,7 +41,13 @@ def test_eye(): | |||
| ) | |||
| def test_concat(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_concat(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def get_data_shape(length: int): | |||
| return (length, 2, 3) | |||
| @@ -50,18 +59,30 @@ def test_concat(): | |||
| return F.concat([data1, data2]) | |||
| cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] | |||
| opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y])) | |||
| opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) | |||
| def test_concat_device(): | |||
| data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0") | |||
| data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1") | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_concat_device(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0") | |||
| data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1") | |||
| out = F.concat([data1, data2], device="cpu0") | |||
| assert str(out.device).split(":")[0] == "cpu0" | |||
| def test_stack(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_stack(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| data1 = np.random.random((3, 2, 2)).astype("float32") | |||
| data2 = np.random.random((3, 2, 2)).astype("float32") | |||
| data3 = np.random.random((3, 2, 2)).astype("float32") | |||
| @@ -72,12 +93,20 @@ def test_stack(): | |||
| def run(data1, data2): | |||
| return F.stack([data1, data2], axis=ai) | |||
| opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai)) | |||
| opr_test( | |||
| cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network | |||
| ) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_split(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def test_split(): | |||
| data = np.random.random((2, 3, 4, 5)).astype(np.float32) | |||
| inp = tensor(data) | |||
| inp = make_tensor(data, network) | |||
| mge_out0 = F.split(inp, 2, axis=3) | |||
| mge_out1 = F.split(inp, [3], axis=3) | |||
| @@ -106,26 +135,42 @@ def test_split(): | |||
| assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" | |||
| def test_reshape(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_reshape(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x = np.arange(6, dtype="float32") | |||
| xx = tensor(x) | |||
| xx = make_tensor(x, network) | |||
| y = x.reshape(1, 2, 3) | |||
| for shape in [ | |||
| (1, 2, 3), | |||
| (1, -1, 3), | |||
| (1, tensor(-1), 3), | |||
| (1, make_tensor(-1, network), 3), | |||
| np.array([1, -1, 3], dtype="int32"), | |||
| tensor([1, -1, 3]), | |||
| make_tensor([1, -1, 3], network), | |||
| ]: | |||
| yy = F.reshape(xx, shape) | |||
| np.testing.assert_equal(yy.numpy(), y) | |||
| def test_reshape_shape_inference(): | |||
| x_shape_known = tensor([1, 2, 3, 4], dtype="float32") | |||
| x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum()) | |||
| tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_reshape_shape_inference(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x_shape_known = make_tensor([1, 2, 3, 4], network) | |||
| x_shape_unknown = F.broadcast_to( | |||
| make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum() | |||
| ) | |||
| tshp_unknown = astensor1d( | |||
| (make_tensor([2], network), make_tensor([2], network)), x_shape_known | |||
| ) | |||
| tshp_known = astensor1d((2, 2), x_shape_known) | |||
| tshp_known_unspec = astensor1d((2, -1), x_shape_known) | |||
| @@ -146,12 +191,18 @@ def test_reshape_shape_inference(): | |||
| {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, | |||
| {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, | |||
| ] | |||
| opr_test(cases, func, compare_fn=check_shape, test_trace=True) | |||
| opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_squeeze(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def test_squeeze(): | |||
| x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) | |||
| xx = tensor(x) | |||
| xx = make_tensor(x, network) | |||
| for axis in [None, 3, -4, (3, -4)]: | |||
| y = np.squeeze(x, axis) | |||
| @@ -159,9 +210,15 @@ def test_squeeze(): | |||
| np.testing.assert_equal(y, yy.numpy()) | |||
| def test_expand_dims(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_expand_dims(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x = np.arange(6, dtype="float32").reshape(2, 3) | |||
| xx = tensor(x) | |||
| xx = make_tensor(x, network) | |||
| for axis in [2, -3, (3, -4), (1, -4)]: | |||
| y = np.expand_dims(x, axis) | |||
| @@ -169,11 +226,17 @@ def test_expand_dims(): | |||
| np.testing.assert_equal(y, yy.numpy()) | |||
| def test_elemwise_dtype_promotion(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_elemwise_dtype_promotion(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x = np.random.rand(2, 3).astype("float32") | |||
| y = np.random.rand(1, 3).astype("float16") | |||
| xx = tensor(x) | |||
| yy = tensor(y) | |||
| xx = make_tensor(x, network) | |||
| yy = make_tensor(y, network) | |||
| z = xx * yy | |||
| np.testing.assert_equal(z.numpy(), x * y) | |||
| @@ -184,7 +247,13 @@ def test_elemwise_dtype_promotion(): | |||
| np.testing.assert_equal(z.numpy(), x - y) | |||
| def test_linspace(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_linspace(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| cases = [ | |||
| {"input": [1, 9, 9]}, | |||
| {"input": [3, 10, 8]}, | |||
| @@ -193,6 +262,7 @@ def test_linspace(): | |||
| cases, | |||
| F.linspace, | |||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| cases = [ | |||
| @@ -203,20 +273,28 @@ def test_linspace(): | |||
| cases, | |||
| F.linspace, | |||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| cases = [ | |||
| {"input": [1, tensor(9), 9]}, | |||
| {"input": [tensor(1), 9, tensor(9)]}, | |||
| {"input": [1, make_tensor(9, network), 9]}, | |||
| {"input": [make_tensor(1, network), 9, make_tensor(9, network)]}, | |||
| ] | |||
| opr_test( | |||
| cases, | |||
| F.linspace, | |||
| ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| def test_arange(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_arange(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| cases = [ | |||
| {"input": [1, 9, 1]}, | |||
| {"input": [2, 10, 2]}, | |||
| @@ -225,6 +303,7 @@ def test_arange(): | |||
| cases, | |||
| F.arange, | |||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| cases = [ | |||
| @@ -235,6 +314,7 @@ def test_arange(): | |||
| cases, | |||
| F.arange, | |||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| cases = [ | |||
| @@ -245,20 +325,33 @@ def test_arange(): | |||
| cases, | |||
| F.arange, | |||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
| network=network, | |||
| ) | |||
| def test_round(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_round(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| data1_shape = (15,) | |||
| data2_shape = (25,) | |||
| data1 = np.random.random(data1_shape).astype(np.float32) | |||
| data2 = np.random.random(data2_shape).astype(np.float32) | |||
| cases = [{"input": data1}, {"input": data2}] | |||
| opr_test(cases, F.round, ref_fn=np.round) | |||
| opr_test(cases, F.round, ref_fn=np.round, network=network) | |||
| def test_flatten(): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_flatten(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| data0_shape = (2, 3, 4, 5) | |||
| data1_shape = (4, 5, 6, 7) | |||
| data0 = np.random.random(data0_shape).astype(np.float32) | |||
| @@ -273,7 +366,7 @@ def test_flatten(): | |||
| {"input": data0, "output": output0}, | |||
| {"input": data1, "output": output1}, | |||
| ] | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn) | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, network=network) | |||
| output0 = (2, 3 * 4 * 5) | |||
| output1 = (4, 5 * 6 * 7) | |||
| @@ -281,7 +374,7 @@ def test_flatten(): | |||
| {"input": data0, "output": output0}, | |||
| {"input": data1, "output": output1}, | |||
| ] | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1) | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network) | |||
| output0 = (2, 3, 4 * 5) | |||
| output1 = (4, 5, 6 * 7) | |||
| @@ -289,7 +382,7 @@ def test_flatten(): | |||
| {"input": data0, "output": output0}, | |||
| {"input": data1, "output": output1}, | |||
| ] | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2) | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network) | |||
| output0 = (2, 3 * 4, 5) | |||
| output1 = (4, 5 * 6, 7) | |||
| @@ -297,10 +390,23 @@ def test_flatten(): | |||
| {"input": data0, "output": output0}, | |||
| {"input": data1, "output": output1}, | |||
| ] | |||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) | |||
| opr_test( | |||
| cases, | |||
| F.flatten, | |||
| compare_fn=compare_fn, | |||
| start_axis=1, | |||
| end_axis=2, | |||
| network=network, | |||
| ) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_broadcast(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def test_broadcast(): | |||
| input1_shape = (20, 30) | |||
| output1_shape = (30, 20, 30) | |||
| data1 = np.random.random(input1_shape).astype(np.float32) | |||
| @@ -321,7 +427,7 @@ def test_broadcast(): | |||
| {"input": [data2, output2_shape], "output": output2_shape}, | |||
| {"input": [data3, output3_shape], "output": output3_shape}, | |||
| ] | |||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network) | |||
| x = F.ones((2, 1, 3)) | |||
| with pytest.raises(RuntimeError): | |||
| @@ -334,35 +440,41 @@ def test_broadcast(): | |||
| F.broadcast_to(x, (1, 3)) | |||
| def test_utils_astensor1d(): | |||
| reference = tensor(0) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_utils_astensor1d(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| reference = make_tensor(0, network) | |||
| # literal | |||
| x = [1, 2, 3] | |||
| for dtype in [None, "float32"]: | |||
| xx = astensor1d(x, reference, dtype=dtype) | |||
| assert type(xx) is tensor | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(xx.numpy(), x) | |||
| # numpy array | |||
| x = np.asarray([1, 2, 3], dtype="int32") | |||
| for dtype in [None, "float32"]: | |||
| xx = astensor1d(x, reference, dtype=dtype) | |||
| assert type(xx) is tensor | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x) | |||
| # tensor | |||
| x = tensor([1, 2, 3], dtype="int32") | |||
| x = make_tensor([1, 2, 3], network) | |||
| for dtype in [None, "float32"]: | |||
| xx = astensor1d(x, reference, dtype=dtype) | |||
| assert type(xx) is tensor | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(xx.numpy(), x.numpy()) | |||
| # mixed | |||
| x = [1, tensor(2), 3] | |||
| x = [1, make_tensor(2, network), 3] | |||
| for dtype in [None, "float32"]: | |||
| xx = astensor1d(x, reference, dtype=dtype) | |||
| assert type(xx) is tensor | |||
| assert isinstance(xx, type(reference)) | |||
| np.testing.assert_equal(xx.numpy(), [1, 2, 3]) | |||
| @@ -382,35 +494,60 @@ def test_device(): | |||
| np.testing.assert_almost_equal(y5.numpy(), y6.numpy()) | |||
| def test_identity(): | |||
| x = tensor(np.random.random((5, 10)).astype(np.float32)) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_identity(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| x = make_tensor(np.random.random((5, 10)).astype(np.float32), network) | |||
| y = F.copy(x) | |||
| np.testing.assert_equal(y.numpy(), x) | |||
| def copy_test(dst, src): | |||
| def copy_test(dst, src, network): | |||
| data = np.random.random((2, 3)).astype(np.float32) | |||
| x = tensor(data, device=src) | |||
| x = make_tensor(data, device=src, network=network) | |||
| y = F.copy(x, dst) | |||
| assert np.allclose(data, y.numpy()) | |||
| z = x.to(dst) | |||
| assert np.allclose(data, z.numpy()) | |||
| if network is None: | |||
| z = x.to(dst) | |||
| assert np.allclose(data, z.numpy()) | |||
| @pytest.mark.require_ngpu(1) | |||
| def test_copy_h2d(): | |||
| copy_test("cpu0", "gpu0") | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_copy_h2d(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| copy_test("cpu0", "gpu0", network=network) | |||
| @pytest.mark.require_ngpu(1) | |||
| def test_copy_d2h(): | |||
| copy_test("gpu0", "cpu0") | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_copy_d2h(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| copy_test("gpu0", "cpu0", network=network) | |||
| @pytest.mark.require_ngpu(2) | |||
| def test_copy_d2d(): | |||
| copy_test("gpu0", "gpu1") | |||
| copy_test("gpu0:0", "gpu0:1") | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_copy_d2d(is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| copy_test("gpu0", "gpu1", network=network) | |||
| copy_test("gpu0:0", "gpu0:1", network=network) | |||
| @pytest.mark.parametrize( | |||
| @@ -425,7 +562,13 @@ def test_copy_d2d(): | |||
| ((), 10, None), | |||
| ], | |||
| ) | |||
| def test_repeat(shape, repeats, axis): | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_repeat(shape, repeats, axis, is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def repeat_func(inp): | |||
| return F.repeat(inp=inp, repeats=repeats, axis=axis) | |||
| @@ -437,7 +580,10 @@ def test_repeat(shape, repeats, axis): | |||
| cases = [{"input": np.array(1.23)}] | |||
| opr_test( | |||
| cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis), | |||
| cases, | |||
| repeat_func, | |||
| ref_fn=lambda inp: np.repeat(inp, repeats, axis), | |||
| network=network, | |||
| ) | |||
| @@ -450,14 +596,16 @@ def test_repeat(shape, repeats, axis): | |||
| ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)), | |||
| ], | |||
| ) | |||
| def test_tile(shape, reps): | |||
| @pytest.mark.parametrize("is_varnode", [True]) | |||
| def test_tile(shape, reps, is_varnode): | |||
| if is_varnode: | |||
| network = Network() | |||
| else: | |||
| network = None | |||
| def tile_func(inp): | |||
| return F.tile(inp=inp, reps=reps) | |||
| cases = [ | |||
| {"input": np.random.randn(*shape).astype("float32")}, | |||
| ] | |||
| cases = [{"input": np.random.randn(*shape).astype("float32")}] | |||
| opr_test( | |||
| cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), | |||
| ) | |||
| opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network) | |||
| @@ -34,13 +34,11 @@ def test_replace_var(): | |||
| vara = graph.var_filter.name("a").as_unique() | |||
| varb = graph.var_filter.name("b").as_unique() | |||
| out = F.mul(vara.var, varb.var) | |||
| out = F.mul(vara, varb) | |||
| out = F.relu(out) | |||
| var_list = graph.add_dep_oprs(out) | |||
| opnode = list(graph.opr_filter.has_input(vara)) | |||
| repl_dict = {opnode[0].outputs[0]: var_list[0]} | |||
| repl_dict = {opnode[0].outputs[0]: out} | |||
| graph.replace_vars(repl_dict) | |||
| modified_model = io.BytesIO() | |||
| @@ -72,14 +70,12 @@ def test_replace_opr(): | |||
| vara = graph.var_filter.name("a").as_unique() | |||
| varb = graph.var_filter.name("b").as_unique() | |||
| out1 = F.sub(vara.var, varb.var) | |||
| out1 = F.sub(vara, varb) | |||
| out1 = F.relu(out1) | |||
| var_list = graph.add_dep_oprs(out1) | |||
| repl_opr = as_oprnode(var_list) | |||
| out1 = graph.add_dep_oprs(out1) | |||
| orig_opr = graph.opr_filter.has_input(vara).as_unique() | |||
| repl_dict = {orig_opr: repl_opr} | |||
| repl_dict = {orig_opr: out1[0].owner} | |||
| graph.replace_oprs(repl_dict) | |||
| modified_model1 = io.BytesIO() | |||
| graph.dump(modified_model1) | |||
| @@ -171,8 +167,7 @@ def test_add_input(): | |||
| inp_c = graph.make_input_node((2,), np.int32, name="c") | |||
| varo = graph.var_filter.name("o").as_unique() | |||
| out = F.add(varo.var, inp_c.var) | |||
| out = graph.add_dep_oprs(out)[0] | |||
| out = F.add(varo, inp_c) | |||
| out.name = "o1" | |||
| graph.remove_output(varo) | |||
| graph.add_output(out) | |||
| @@ -206,12 +201,11 @@ def test_add_output(): | |||
| var_a = net.var_filter.name("a").as_unique() | |||
| var_b = net.var_filter.name("b").as_unique() | |||
| y = F.add(var_a.var, var_b.var) | |||
| y = F.add(var_a, var_b) | |||
| y = F.sigmoid(y) | |||
| new_vars = net.add_dep_oprs(y)[0] | |||
| new_vars.name = "o1" | |||
| net.add_output(new_vars) | |||
| y.name = "o1" | |||
| net.add_output(y) | |||
| modified_model = io.BytesIO() | |||
| net.dump(modified_model) | |||