GitOrigin-RevId: 6e4d05b475
tags/v1.4.0-rc1
| @@ -8,6 +8,9 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import numpy as np | import numpy as np | ||||
| from .._imperative_rt import make_const | |||||
| from .._imperative_rt.core2 import SymbolVar, Tensor | |||||
| class Const: | class Const: | ||||
| def __init__(self, value=None, *, dtype=None, device=None): | def __init__(self, value=None, *, dtype=None, device=None): | ||||
| @@ -19,7 +22,19 @@ class Const: | |||||
| from ...tensor import Tensor | from ...tensor import Tensor | ||||
| device = self.device | device = self.device | ||||
| if device is None: | |||||
| device = reference[0].device | |||||
| if len(reference) != 0: | |||||
| reference = reference[0] | |||||
| assert isinstance( | |||||
| reference, (SymbolVar, Tensor) | |||||
| ), "Reference should be Tensor or VarNode" | |||||
| if device is None: | |||||
| device = reference.device | |||||
| if isinstance(reference, SymbolVar): | |||||
| cls = type(reference) | |||||
| rst = cls(make_const(reference.graph, self.value, device, self.dtype)) | |||||
| return (rst,) | |||||
| return (Tensor(self.value, self.dtype, self.device, True),) | return (Tensor(self.value, self.dtype, self.device, True),) | ||||
| @@ -13,7 +13,7 @@ from typing import Union | |||||
| import numpy as np | import numpy as np | ||||
| from .._imperative_rt.common import CompNode | from .._imperative_rt.common import CompNode | ||||
| from .._imperative_rt.core2 import Tensor, apply | |||||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply | |||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from ..ops.builtin import Elemwise, GetVarShape | from ..ops.builtin import Elemwise, GetVarShape | ||||
| from . import utils | from . import utils | ||||
| @@ -230,7 +230,9 @@ def _todo(*_): | |||||
| def _expand_args(args): | def _expand_args(args): | ||||
| if len(args) == 1: | if len(args) == 1: | ||||
| if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),): | |||||
| if isinstance( | |||||
| args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray), | |||||
| ): | |||||
| args = args[0] | args = args[0] | ||||
| return args | return args | ||||
| @@ -10,7 +10,7 @@ from typing import Iterable | |||||
| import numpy as np | import numpy as np | ||||
| from .._imperative_rt.core2 import Tensor, apply | |||||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply | |||||
| from .._trace_option import use_symbolic_shape | from .._trace_option import use_symbolic_shape | ||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from ..ops.special import Const | from ..ops.special import Const | ||||
| @@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
| return True | return True | ||||
| def get_index(i): | def get_index(i): | ||||
| if not isinstance(i, (Tensor)): | |||||
| if not isinstance(i, (Tensor, SymbolVar)): | |||||
| if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | ||||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)() | |||||
| (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||||
| else: | else: | ||||
| (i,) = Const(i, dtype=np.int32, device=inp.device)() | |||||
| (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||||
| return i | return i | ||||
| assert isinstance(i, Tensor) | |||||
| assert isinstance(i, (Tensor, SymbolVar)) | |||||
| if i.dtype != np.bool_: | if i.dtype != np.bool_: | ||||
| return i | return i | ||||
| _, ind = apply(builtin.CondTake(), i, i) | _, ind = apply(builtin.CondTake(), i, i) | ||||
| @@ -197,9 +197,9 @@ def try_condtake(tensor, index): | |||||
| ): | ): | ||||
| return [] | return [] | ||||
| if isinstance(index, np.ndarray): | if isinstance(index, np.ndarray): | ||||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)() | |||||
| assert isinstance(index, Tensor) | |||||
| if not isinstance(tensor, Tensor): | |||||
| (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | |||||
| assert isinstance(index, (Tensor, SymbolVar)) | |||||
| if not isinstance(tensor, (Tensor, SymbolVar)): | |||||
| raise TypeError("input must be a tensor") | raise TypeError("input must be a tensor") | ||||
| if tensor.device != index.device: | if tensor.device != index.device: | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -214,11 +214,16 @@ def getitem(tensor, index): | |||||
| return try_result[0] | return try_result[0] | ||||
| tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) | tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) | ||||
| for v in tensors: | for v in tensors: | ||||
| if v.shape is None: | |||||
| break | |||||
| if isinstance(v.shape, v.__class__): | if isinstance(v.shape, v.__class__): | ||||
| break | break | ||||
| if len(v.shape) > 0 and v.shape[0] == 0: | if len(v.shape) > 0 and v.shape[0] == 0: | ||||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() | |||||
| (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||||
| tensor | |||||
| ) | |||||
| return empty_tensor | return empty_tensor | ||||
| if use_subtensor: | if use_subtensor: | ||||
| op = builtin.Subtensor(items=items) | op = builtin.Subtensor(items=items) | ||||
| else: | else: | ||||
| @@ -235,8 +240,8 @@ def setitem(tensor, index, value): | |||||
| if len(try_result) == 2: | if len(try_result) == 2: | ||||
| index = try_result[1] | index = try_result[1] | ||||
| tensor = tensor.reshape(-1) | tensor = tensor.reshape(-1) | ||||
| if not isinstance(value, Tensor): | |||||
| (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() | |||||
| if not isinstance(value, (Tensor, SymbolVar)): | |||||
| (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) | |||||
| tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | ||||
| if use_subtensor: | if use_subtensor: | ||||
| op = builtin.Subtensor(items=items) | op = builtin.Subtensor(items=items) | ||||
| @@ -11,8 +11,9 @@ from typing import Iterable, Union | |||||
| import numpy as np | import numpy as np | ||||
| from .._imperative_rt import VarNode | |||||
| from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||||
| from .._imperative_rt import make_const | |||||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device | |||||
| from .._wrap import device as as_device | |||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from ..ops.special import Const | from ..ops.special import Const | ||||
| from .dtype import is_dtype_equal, is_quantize | from .dtype import is_dtype_equal, is_quantize | ||||
| @@ -38,13 +39,9 @@ def set_convert_inputs(flag): | |||||
| def concatenate(inputs, axis=0, *, device=None): | def concatenate(inputs, axis=0, *, device=None): | ||||
| dtype = dtype_promotion(inputs) | |||||
| device = get_device(inputs) | |||||
| def convert(x): | |||||
| return convert_single_value(x, dtype=dtype, device=device) | |||||
| inputs = tuple(map(convert, inputs)) | |||||
| inputs = convert_inputs(*inputs) | |||||
| if device is None: | |||||
| device = get_device(inputs) | |||||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | ||||
| return result | return result | ||||
| @@ -60,7 +57,7 @@ def astype(x, dtype): | |||||
| def convert_single_value(v, *, dtype=None, device=None): | def convert_single_value(v, *, dtype=None, device=None): | ||||
| if isinstance(v, (Tensor, VarNode)): | |||||
| if isinstance(v, (Tensor, SymbolVar)): | |||||
| if not is_quantize(v.dtype): | if not is_quantize(v.dtype): | ||||
| v = astype(v, dtype) | v = astype(v, dtype) | ||||
| else: | else: | ||||
| @@ -68,17 +65,35 @@ def convert_single_value(v, *, dtype=None, device=None): | |||||
| return v | return v | ||||
| def convert_inputs(*args: Tensor): | |||||
| def convert_inputs(*args, device=None): | |||||
| if not _enable_convert_inputs: | if not _enable_convert_inputs: | ||||
| return args | return args | ||||
| dtype = dtype_promotion(args) | dtype = dtype_promotion(args) | ||||
| device = get_device(args) | |||||
| if device is None: | |||||
| device = get_device(args) | |||||
| device = as_device(device) | |||||
| graph = None | |||||
| sym_type = None | |||||
| for a in args: | |||||
| if isinstance(a, SymbolVar): | |||||
| if graph is None: | |||||
| graph = a.var.graph | |||||
| sym_type = type(a) | |||||
| else: | |||||
| assert graph == a.var.graph | |||||
| args = list(args) | |||||
| if graph is not None: | |||||
| for i in range(len(args)): | |||||
| if not isinstance(args[i], SymbolVar): | |||||
| rst = make_const(graph, np.array(args[i]), device.to_c(), dtype) | |||||
| args[i] = sym_type(rst) | |||||
| def convert(value): | def convert(value): | ||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| return convert_single_value(value, dtype=dtype, device=device) | |||||
| return convert_single_value(value, dtype=dtype, device=device.to_c()) | |||||
| return tuple(map(convert, args)) | return tuple(map(convert, args)) | ||||
| @@ -98,14 +113,14 @@ def result_type(*args): | |||||
| def isscalar(x): | def isscalar(x): | ||||
| if isinstance(x, Tensor): | |||||
| if isinstance(x, (Tensor, SymbolVar)): | |||||
| return x._isscalar() | return x._isscalar() | ||||
| return np.isscalar(x) | return np.isscalar(x) | ||||
| def setscalar(x): | def setscalar(x): | ||||
| if isinstance(x, Tensor): | |||||
| if isinstance(x, (Tensor, SymbolVar)): | |||||
| x._setscalar() | x._setscalar() | ||||
| else: | else: | ||||
| raise NotImplementedError("Unsupport type {}".format(type(x))) | raise NotImplementedError("Unsupport type {}".format(type(x))) | ||||
| @@ -132,7 +147,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
| if not isinstance(x, collections.abc.Sequence): | if not isinstance(x, collections.abc.Sequence): | ||||
| raise TypeError | raise TypeError | ||||
| if any(isinstance(i, Tensor) for i in x): | |||||
| if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | |||||
| x = concatenate(x, device=device) | x = concatenate(x, device=device) | ||||
| if dtype is not None: | if dtype is not None: | ||||
| x = astype(x, dtype) | x = astype(x, dtype) | ||||
| @@ -142,7 +157,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
| def _expand_int(s, i): | def _expand_int(s, i): | ||||
| if isinstance(i, Tensor): | |||||
| if isinstance(i, (Tensor, SymbolVar)): | |||||
| i_np = i.numpy() | i_np = i.numpy() | ||||
| if i_np.ndim == 0: | if i_np.ndim == 0: | ||||
| s.append(int(i_np)) | s.append(int(i_np)) | ||||
| @@ -9,8 +9,7 @@ | |||||
| # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | ||||
| import numpy as np | import numpy as np | ||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core._imperative_rt.graph import VarNode | |||||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.builtin import Elemwise | from ..core.ops.builtin import Elemwise | ||||
| from ..core.tensor import utils | from ..core.tensor import utils | ||||
| @@ -72,7 +71,7 @@ __all__ = [ | |||||
| def _elwise(*args, mode): | def _elwise(*args, mode): | ||||
| tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) | |||||
| tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args)) | |||||
| if len(tensor_args) == 0: | if len(tensor_args) == 0: | ||||
| dtype = utils.dtype_promotion(args) | dtype = utils.dtype_promotion(args) | ||||
| first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) | first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) | ||||
| @@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Union | |||||
| import numpy as np | import numpy as np | ||||
| from ..core._imperative_rt import CompNode | from ..core._imperative_rt import CompNode | ||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||||
| from ..core._wrap import device as as_device | from ..core._wrap import device as as_device | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.builtin import Copy, Identity | from ..core.ops.builtin import Copy, Identity | ||||
| @@ -101,7 +101,7 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Ten | |||||
| return result | return result | ||||
| def full(shape, value, dtype="float32", device=None): | |||||
| def full(shape, value, dtype="float32", device=None) -> Tensor: | |||||
| """ | """ | ||||
| Returns a tensor with given shape and value. | Returns a tensor with given shape and value. | ||||
| """ | """ | ||||
| @@ -115,7 +115,7 @@ def full(shape, value, dtype="float32", device=None): | |||||
| return broadcast_to(x, shape) | return broadcast_to(x, shape) | ||||
| def ones(shape, dtype="float32", device=None): | |||||
| def ones(shape, dtype="float32", device=None) -> Tensor: | |||||
| """ | """ | ||||
| Returns a ones tensor with given shape. | Returns a ones tensor with given shape. | ||||
| @@ -142,14 +142,14 @@ def ones(shape, dtype="float32", device=None): | |||||
| return full(shape, 1.0, dtype=dtype, device=device) | return full(shape, 1.0, dtype=dtype, device=device) | ||||
| def zeros(shape, dtype="float32", device=None): | |||||
| def zeros(shape, dtype="float32", device=None) -> Tensor: | |||||
| """ | """ | ||||
| Returns a zero tensor with given shape. | Returns a zero tensor with given shape. | ||||
| """ | """ | ||||
| return full(shape, 0.0, dtype=dtype, device=device) | return full(shape, 0.0, dtype=dtype, device=device) | ||||
| def zeros_like(inp: Tensor) -> Tensor: | |||||
| def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||||
| """ | """ | ||||
| Returns a zero tensor with the same shape as input tensor. | Returns a zero tensor with the same shape as input tensor. | ||||
| @@ -176,21 +176,26 @@ def zeros_like(inp: Tensor) -> Tensor: | |||||
| [0 0 0]] | [0 0 0]] | ||||
| """ | """ | ||||
| return zeros(inp.shape, dtype=inp.dtype, device=inp.device) | |||||
| return full_like(inp, 0.0) | |||||
| def ones_like(inp: Tensor) -> Tensor: | |||||
| def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||||
| """ | """ | ||||
| Returns a ones tensor with the same shape as input tensor. | Returns a ones tensor with the same shape as input tensor. | ||||
| """ | """ | ||||
| return ones(inp.shape, dtype=inp.dtype, device=inp.device) | |||||
| return full_like(inp, 1.0) | |||||
| def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||||
| def full_like( | |||||
| inp: Union[Tensor, SymbolVar], value: Union[int, float] | |||||
| ) -> Union[Tensor, SymbolVar]: | |||||
| """ | """ | ||||
| Returns a tensor filled with given value with the same shape as input tensor. | Returns a tensor filled with given value with the same shape as input tensor. | ||||
| """ | """ | ||||
| return full(inp.shape, value, dtype=inp.dtype, device=inp.device) | |||||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||||
| if inp.shape is (): | |||||
| return x | |||||
| return broadcast_to(x, inp.shape) | |||||
| def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | ||||
| @@ -259,15 +264,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||||
| if len(inps) == 1: | if len(inps) == 1: | ||||
| return inps[0] | return inps[0] | ||||
| dtype = dtype_promotion(inps) | |||||
| inps = convert_inputs(*inps, device=device) | |||||
| if device is None: | if device is None: | ||||
| device = get_device(inps) | device = get_device(inps) | ||||
| device = as_device(device) | device = as_device(device) | ||||
| def convert(x): | |||||
| return convert_single_value(x, dtype=dtype, device=device) | |||||
| inps = tuple(map(convert, inps)) | |||||
| (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | ||||
| return result | return result | ||||
| @@ -379,8 +379,14 @@ def split(inp, nsplits_or_sections, axis=0): | |||||
| Ntotal, axis, Nsections | Ntotal, axis, Nsections | ||||
| ) | ) | ||||
| ) | ) | ||||
| func = ( | |||||
| floor_div | |||||
| if isinstance(Nsections, (SymbolVar, Tensor)) | |||||
| else lambda x, y: x // y | |||||
| ) | |||||
| div_points = [0] + [ | div_points = [0] + [ | ||||
| floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||||
| func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||||
| ] | ] | ||||
| for i in range(2, Nsections + 1): | for i in range(2, Nsections + 1): | ||||
| div_points[i] = div_points[i - 1] + div_points[i] | div_points[i] = div_points[i - 1] + div_points[i] | ||||
| @@ -925,11 +931,15 @@ def linspace( | |||||
| if not (cur_device is None or device == cur_device): | if not (cur_device is None or device == cur_device): | ||||
| raise ("ambiguous device for linspace opr") | raise ("ambiguous device for linspace opr") | ||||
| if not isinstance(start, Tensor): | |||||
| is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num]) | |||||
| if any(is_symbolvar) and not all(is_symbolvar): | |||||
| raise TypeError("start, stop and num should all be VarNode or none of them") | |||||
| if not isinstance(start, (Tensor, SymbolVar)): | |||||
| start = Tensor(start, device=device) | start = Tensor(start, device=device) | ||||
| if not isinstance(stop, Tensor): | |||||
| if not isinstance(stop, (Tensor, SymbolVar)): | |||||
| stop = Tensor(stop, device=device) | stop = Tensor(stop, device=device) | ||||
| if not isinstance(num, Tensor): | |||||
| if not isinstance(num, (Tensor, SymbolVar)): | |||||
| num = Tensor(num, device=device) | num = Tensor(num, device=device) | ||||
| op = builtin.Linspace(comp_node=device) | op = builtin.Linspace(comp_node=device) | ||||
| @@ -983,7 +993,7 @@ def arange( | |||||
| stop = stop.astype("float32") | stop = stop.astype("float32") | ||||
| if isinstance(step, Tensor): | if isinstance(step, Tensor): | ||||
| step = step.astype("float32") | step = step.astype("float32") | ||||
| num = ceil(Tensor((stop - start) / step, device=device)) | |||||
| num = ceil((stop - start) / step) | |||||
| stop = start + step * (num - 1) | stop = start + step * (num - 1) | ||||
| result = linspace(start, stop, num, device=device) | result = linspace(start, stop, num, device=device) | ||||
| if np.dtype(dtype) == np.int32: | if np.dtype(dtype) == np.int32: | ||||
| @@ -16,6 +16,7 @@ from typing import Dict, List | |||||
| import numpy as np | import numpy as np | ||||
| from ..core._imperative_rt import ComputingGraph | from ..core._imperative_rt import ComputingGraph | ||||
| from ..core._imperative_rt.core2 import SymbolVar | |||||
| from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
| from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | ||||
| from .network_node import ( | from .network_node import ( | ||||
| @@ -60,12 +61,12 @@ class Network: | |||||
| ) | ) | ||||
| outputs = [new_outputs[i] for i in outspec] | outputs = [new_outputs[i] for i in outspec] | ||||
| self._orig_outputs = outputs | self._orig_outputs = outputs | ||||
| self.add_dep_oprs(*outputs) | |||||
| for x in self._orig_outputs: | |||||
| self.output_vars.append(self._get_var(x)) | |||||
| self.add_dep_oprs() | |||||
| for x in self._orig_inputs: | for x in self._orig_inputs: | ||||
| self.input_vars.append(self._get_var(x)) | self.input_vars.append(self._get_var(x)) | ||||
| for x in self._orig_outputs: | |||||
| self.output_vars.append(self._get_var(x)) | |||||
| self.graph = self._orig_outputs[0].graph | self.graph = self._orig_outputs[0].graph | ||||
| return self | return self | ||||
| @@ -197,6 +198,8 @@ class Network: | |||||
| def add_output(self, *vars: VarNode): | def add_output(self, *vars: VarNode): | ||||
| """Adds vars into the network output node list | """Adds vars into the network output node list | ||||
| """ | """ | ||||
| if not all([var.owner for var in vars]): | |||||
| self.add_dep_oprs(*vars) | |||||
| for var in vars: | for var in vars: | ||||
| if var not in self.output_vars: | if var not in self.output_vars: | ||||
| self.output_vars.append(var) | self.output_vars.append(var) | ||||
| @@ -209,21 +212,25 @@ class Network: | |||||
| self.output_vars.remove(var) | self.output_vars.remove(var) | ||||
| def add_dep_oprs(self, *vars): | def add_dep_oprs(self, *vars): | ||||
| """Adds dependent opnodes and varnodes of vars into network | |||||
| """ | |||||
| oprs = get_oprs_seq(vars, False, False) | |||||
| for mge_opr in oprs: | |||||
| if len(vars) == 0: | |||||
| vars = self.output_vars | |||||
| q = list(vars) | |||||
| while len(q) > 0: | |||||
| cur = q.pop(0) | |||||
| if cur.owner is not None: | |||||
| continue | |||||
| if cur.name is None: | |||||
| cur.name = cur.var.name | |||||
| self.all_vars_map[cur.var.id] = cur | |||||
| mge_opr = cur.var.owner | |||||
| if get_opr_type(mge_opr) == "Host2DeviceCopy": | if get_opr_type(mge_opr) == "Host2DeviceCopy": | ||||
| self._orig_inputs.extend(mge_opr.outputs) | self._orig_inputs.extend(mge_opr.outputs) | ||||
| opr = self._add_opr(mge_opr) | |||||
| if opr is not None: | |||||
| for x in mge_opr.inputs: | |||||
| opr.add_inp_var(self._get_var(x)) | |||||
| # set out var | |||||
| for x in mge_opr.outputs: | |||||
| opr.add_out_var(self._get_var(x)) | |||||
| return [self.all_vars_map[var.id] for var in vars] | |||||
| cur.owner = self._add_opr(mge_opr) | |||||
| if cur.owner is None: | |||||
| cur.owner = self.all_oprs_map[mge_opr.id] | |||||
| continue | |||||
| q.extend(cur.owner.inputs) | |||||
| return list(vars) | |||||
| def modify_opr_names(self, modifier): | def modify_opr_names(self, modifier): | ||||
| """Modifies names of operators **inplace**; useful for merging loaded | """Modifies names of operators **inplace**; useful for merging loaded | ||||
| @@ -275,6 +282,9 @@ class Network: | |||||
| Replaces vars in the graph. | Replaces vars in the graph. | ||||
| :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | ||||
| """ | """ | ||||
| if not all([var.owner for var in repl_dict.values()]): | |||||
| print(repl_dict.values()) | |||||
| self.add_dep_oprs(*list(repl_dict.values())) | |||||
| for var in self.all_vars: | for var in self.all_vars: | ||||
| if var in repl_dict: | if var in repl_dict: | ||||
| repl_var = repl_dict[var] | repl_var = repl_dict[var] | ||||
| @@ -282,6 +292,7 @@ class Network: | |||||
| idx = owner.outputs.index(repl_var) | idx = owner.outputs.index(repl_var) | ||||
| owner.outputs[idx] = var | owner.outputs[idx] = var | ||||
| var.__dict__.update(repl_var.__dict__) | var.__dict__.update(repl_var.__dict__) | ||||
| var.var = repl_var.var | |||||
| def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | ||||
| """ | """ | ||||
| @@ -297,6 +308,7 @@ class Network: | |||||
| for ind, var in enumerate(opr.outputs): | for ind, var in enumerate(opr.outputs): | ||||
| var.owner = repl_dict[opr] | var.owner = repl_dict[opr] | ||||
| var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | ||||
| var.var = repl_dict[opr].outputs[ind].var | |||||
| def get_opr_by_type(self, oprcls, unique=True): | def get_opr_by_type(self, oprcls, unique=True): | ||||
| assert issubclass(oprcls, OpNode) | assert issubclass(oprcls, OpNode) | ||||
| @@ -381,11 +393,16 @@ class Network: | |||||
| return self.opr_filter.as_dict() | return self.opr_filter.as_dict() | ||||
| # used for loading and building graph | # used for loading and building graph | ||||
| def _add_opr(self, x): | |||||
| def _add_opr(self, opr): | |||||
| # TODO: use megbrain C++ RTTI to replace type string | # TODO: use megbrain C++ RTTI to replace type string | ||||
| if x.id not in self.all_oprs_map: | |||||
| self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x) | |||||
| return self.all_oprs_map[x.id] | |||||
| if opr.id not in self.all_oprs_map: | |||||
| opnode = str_to_mge_class(get_opr_type(opr)).load(opr) | |||||
| self.all_oprs_map[opr.id] = opnode | |||||
| for var in opr.inputs: | |||||
| opnode.add_inp_var(self._get_var(var)) | |||||
| for var in opr.outputs: | |||||
| opnode.add_out_var(self._get_var(var)) | |||||
| return opnode | |||||
| else: | else: | ||||
| return None | return None | ||||
| @@ -397,7 +414,7 @@ class Network: | |||||
| def _get_var(self, x): | def _get_var(self, x): | ||||
| # auto convert to VarNode of Network | # auto convert to VarNode of Network | ||||
| if x.id not in self.all_vars_map: | |||||
| if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x: | |||||
| self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | ||||
| return self.all_vars_map[x.id] | return self.all_vars_map[x.id] | ||||
| @@ -652,7 +669,7 @@ class NodeFilterHasInput(NodeFilter): | |||||
| assert isinstance( | assert isinstance( | ||||
| i, OpNode | i, OpNode | ||||
| ), "has_input() must be used with OpNode; " "got {!r}".format(i) | ), "has_input() must be used with OpNode; " "got {!r}".format(i) | ||||
| if self.var in i.inputs: | |||||
| if any(self.var is _ for _ in i.inputs): | |||||
| yield i | yield i | ||||
| @@ -6,16 +6,21 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import abc | |||||
| import json | import json | ||||
| import sys | import sys | ||||
| from typing import Callable | |||||
| from typing import Callable, Sequence | |||||
| import numpy as np | import numpy as np | ||||
| from ..core import _imperative_rt as rt | from ..core import _imperative_rt as rt | ||||
| from ..core._imperative_rt.core2 import SymbolVar | |||||
| from ..core._wrap import Device | from ..core._wrap import Device | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.tensor.megbrain_graph import InputNode | |||||
| from ..core.tensor.array_method import ArrayMethodMixin | |||||
| from ..core.tensor.indexing import getitem as _getitem | |||||
| from ..core.tensor.indexing import setitem as _setitem | |||||
| from ..core.tensor.megbrain_graph import InputNode, OutputNode | |||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from .comp_graph_tools import replace_vars | from .comp_graph_tools import replace_vars | ||||
| from .module_stats import ( | from .module_stats import ( | ||||
| @@ -29,9 +34,13 @@ class NetworkNode: | |||||
| pass | pass | ||||
| class VarNode(NetworkNode): | |||||
| def __init__(self, owner_opr=None, name=None): | |||||
| self.var = None | |||||
| class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)): | |||||
| pass | |||||
| class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||||
| def __init__(self, var=None, *, owner_opr=None, name=None): | |||||
| SymbolVar.__init__(self, var) | |||||
| self.owner = owner_opr | self.owner = owner_opr | ||||
| self.name = name | self.name = name | ||||
| self.id = id(self) | self.id = id(self) | ||||
| @@ -58,6 +67,40 @@ class VarNode(NetworkNode): | |||||
| def dtype(self): | def dtype(self): | ||||
| return self.var.dtype if self.var else None | return self.var.dtype if self.var else None | ||||
| def __bool__(self): | |||||
| return False | |||||
| __index__ = None | |||||
| __int__ = None | |||||
| __float__ = None | |||||
| __complex__ = None | |||||
| def __hash__(self): | |||||
| return id(self) | |||||
| @property | |||||
| def _tuple_shape(self): | |||||
| return self.var.shape | |||||
| def numpy(self): | |||||
| o = OutputNode(self.var) | |||||
| self.graph.compile(o.outputs).execute() | |||||
| return o.get_value().numpy() | |||||
| def __getitem__(self, index): | |||||
| return _getitem(self, index) | |||||
| def __setitem__(self, index, value): | |||||
| if index is not Ellipsis: | |||||
| value = _setitem(self, index, value) | |||||
| if self.owner is not None: | |||||
| idx = self.owner.outputs.index(self) | |||||
| self.owner.outputs[idx] = VarNode( | |||||
| self.var, owner_opr=self.owner, name=self.var.name | |||||
| ) | |||||
| self.var = value.var | |||||
| self.owner = None | |||||
| def set_owner_opr(self, owner_opr): | def set_owner_opr(self, owner_opr): | ||||
| self.owner = owner_opr | self.owner = owner_opr | ||||
| @@ -138,7 +181,7 @@ class Host2DeviceCopy(OpNode): | |||||
| outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | ||||
| self._opr = outputs.owner | self._opr = outputs.owner | ||||
| if len(self.outputs) == 0: | if len(self.outputs) == 0: | ||||
| self.outputs.append(VarNode(self, self.name)) | |||||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||||
| self.outputs[0].var = outputs | self.outputs[0].var = outputs | ||||
| assert self.outputs[0].owner is self | assert self.outputs[0].owner is self | ||||
| @@ -176,8 +219,8 @@ class ImmutableTensor(OpNode): | |||||
| def set_value(self, data, device=None): | def set_value(self, data, device=None): | ||||
| assert self.graph is not None | assert self.graph is not None | ||||
| cn = device if device else self.device | cn = device if device else self.device | ||||
| assert isinstance(data, (int, float, np.ndarray)) | |||||
| if isinstance(data, (int, float)): | |||||
| assert isinstance(data, (int, float, Sequence, np.ndarray)) | |||||
| if not isinstance(data, np.ndarray): | |||||
| data = np.array(data) | data = np.array(data) | ||||
| if data.dtype == np.float64: | if data.dtype == np.float64: | ||||
| data = data.astype(np.float32) | data = data.astype(np.float32) | ||||
| @@ -185,7 +228,7 @@ class ImmutableTensor(OpNode): | |||||
| data = data.astype(np.int32) | data = data.astype(np.int32) | ||||
| varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | ||||
| if len(self.outputs) == 0: | if len(self.outputs) == 0: | ||||
| self.outputs.append(VarNode(self, self.name)) | |||||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||||
| self.outputs[0].var = varnode | self.outputs[0].var = varnode | ||||
| self._opr = varnode.owner | self._opr = varnode.owner | ||||
| @@ -160,16 +160,21 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
| if (ctx.op->same_type<BackwardGraph>()) { | if (ctx.op->same_type<BackwardGraph>()) { | ||||
| ctx.backward = true; | ctx.backward = true; | ||||
| } | } | ||||
| if (py::isinstance<cg::VarNode>(py::handle(args[0]))){ | |||||
| SmallVector<cg::VarNode*> vinputs(nargs); | |||||
| for (size_t i = 0; i < nargs; ++i) { | |||||
| vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>(); | |||||
| } | |||||
| auto op = ctx.op.get(); | |||||
| return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr(); | |||||
| } | |||||
| if (py::isinstance<PySymbolVar>(py::handle(args[0]))){ | |||||
| SmallVector<cg::VarNode*> vinputs(nargs); | |||||
| for (size_t i = 0; i < nargs; ++i) { | |||||
| vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node; | |||||
| } | |||||
| auto op = ctx.op.get(); | |||||
| auto rst = OpDef::apply_on_var_node(*op, vinputs); | |||||
| auto ret = pybind11::tuple(rst.size()); | |||||
| auto typeobj = py::handle(args[0]).get_type(); | |||||
| for (size_t i = 0; i<rst.size(); ++i) { | |||||
| ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic)); | |||||
| } | |||||
| return ret.release().ptr(); | |||||
| } | |||||
| for (size_t i = 0; i < nargs; ++i) { | for (size_t i = 0; i < nargs; ++i) { | ||||
| if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | ||||
| @@ -686,9 +691,9 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (py::isinstance<cg::VarNode>(py::handle(handle))){ | |||||
| auto var = py::handle(handle).cast<cg::VarNode *>(); | |||||
| mgb::DType type = var->dtype(); | |||||
| if (py::isinstance<PySymbolVar>(py::handle(handle))){ | |||||
| auto var = py::handle(handle).cast<PySymbolVar*>(); | |||||
| mgb::DType type = var->m_node->dtype(); | |||||
| auto && descr = npy::dtype_mgb2np_descr(type); | auto && descr = npy::dtype_mgb2np_descr(type); | ||||
| Py_INCREF(descr.get()); | Py_INCREF(descr.get()); | ||||
| tensors.emplace_back(descr.get()); | tensors.emplace_back(descr.get()); | ||||
| @@ -737,19 +742,26 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { | |||||
| bool valid = false; | bool valid = false; | ||||
| CompNode cn; | CompNode cn; | ||||
| for (size_t i = 0; i < nargs; ++i) { | for (size_t i = 0; i < nargs; ++i) { | ||||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||||
| PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; | |||||
| TensorWrapper* tw = TensorWrapper::try_cast(handle); | TensorWrapper* tw = TensorWrapper::try_cast(handle); | ||||
| bool is_var = py::isinstance<cg::VarNode>(py::handle(handle)); | |||||
| if (tw || is_var) { | |||||
| bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle)); | |||||
| if (tw || is_symvar) { | |||||
| if (!valid) { | if (!valid) { | ||||
| cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||||
| cn = tw ? tw->m_tensor->comp_node() | |||||
| : py::handle(handle) | |||||
| .cast<PySymbolVar*>() | |||||
| ->m_node->comp_node(); | |||||
| valid = true; | valid = true; | ||||
| } else { | } else { | ||||
| CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||||
| CompNode cn1 = tw ? tw->m_tensor->comp_node() | |||||
| : py::handle(handle) | |||||
| .cast<PySymbolVar*>() | |||||
| ->m_node->comp_node(); | |||||
| if (cn1 != cn) { | if (cn1 != cn) { | ||||
| throw py::value_error(ssprintf("ambiguous device: %s vs %s", | throw py::value_error(ssprintf("ambiguous device: %s vs %s", | ||||
| cn.to_string().c_str(), cn1.to_string().c_str())); | |||||
| cn.to_string().c_str(), | |||||
| cn1.to_string().c_str())); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -849,6 +861,32 @@ void init_tensor(py::module m) { | |||||
| .def("__call__", &TensorWeakRef::operator()) | .def("__call__", &TensorWeakRef::operator()) | ||||
| .def("_use_cnt", &TensorWeakRef::_use_cnt); | .def("_use_cnt", &TensorWeakRef::_use_cnt); | ||||
| py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar") | |||||
| .def_property_readonly( | |||||
| "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); }) | |||||
| .def_property("var", [](PySymbolVar* v) { return v->m_node; }, | |||||
| [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; }) | |||||
| .def_property_readonly( | |||||
| "device", | |||||
| [](PySymbolVar* v) { return v->m_node->comp_node(); }) | |||||
| .def_property_readonly( | |||||
| "graph", | |||||
| [](PySymbolVar* v) { return v->m_node->owner_graph(); }) | |||||
| .def_property_readonly( | |||||
| "shape", | |||||
| [](PySymbolVar* v) -> const TensorShape* { | |||||
| auto&& mgr = v->m_node->owner_graph() | |||||
| ->static_infer_manager(); | |||||
| return mgr.infer_shape_fallible(v->m_node); | |||||
| }) | |||||
| .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | |||||
| .def("_setscalar", | |||||
| [](PySymbolVar* v) { return v->is_scalar = true; }) | |||||
| .def(py::init([](cg::VarNode* node) { | |||||
| return std::make_shared<PySymbolVar>(node); | |||||
| }), | |||||
| py::arg() = nullptr); | |||||
| static PyMethodDef method_defs[] = { | static PyMethodDef method_defs[] = { | ||||
| MGE_PY_INTERFACE(apply, py_apply), | MGE_PY_INTERFACE(apply, py_apply), | ||||
| MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), | MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), | ||||
| @@ -181,6 +181,12 @@ struct TensorWrapper { | |||||
| PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | ||||
| }; | }; | ||||
| struct PySymbolVar { | |||||
| cg::VarNode* m_node = nullptr; | |||||
| bool is_scalar = false; | |||||
| PySymbolVar() = default; | |||||
| PySymbolVar(VarNode *m): m_node(m){} | |||||
| }; | |||||
| PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | ||||
| @@ -2,9 +2,11 @@ import io | |||||
| import numpy as np | import numpy as np | ||||
| import megengine.core.tensor.megbrain_graph as G | |||||
| import megengine.utils.comp_graph_tools as cgtools | import megengine.utils.comp_graph_tools as cgtools | ||||
| from megengine import tensor | from megengine import tensor | ||||
| from megengine.jit import trace | from megengine.jit import trace | ||||
| from megengine.utils.network_node import VarNode | |||||
| def _default_compare_fn(x, y): | def _default_compare_fn(x, y): | ||||
| @@ -14,8 +16,23 @@ def _default_compare_fn(x, y): | |||||
| np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | ||||
| def make_tensor(x, network=None, device=None): | |||||
| if network is not None: | |||||
| if isinstance(x, VarNode): | |||||
| return VarNode(x.var) | |||||
| return network.make_const(x, device=device) | |||||
| else: | |||||
| return tensor(x, device=device) | |||||
| def opr_test( | def opr_test( | ||||
| cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs | |||||
| cases, | |||||
| func, | |||||
| compare_fn=_default_compare_fn, | |||||
| ref_fn=None, | |||||
| test_trace=True, | |||||
| network=None, | |||||
| **kwargs | |||||
| ): | ): | ||||
| """ | """ | ||||
| :param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | :param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | ||||
| @@ -44,7 +61,7 @@ def opr_test( | |||||
| if not isinstance(results, (tuple, list)): | if not isinstance(results, (tuple, list)): | ||||
| results = (results,) | results = (results,) | ||||
| for r, e in zip(results, expected): | for r, e in zip(results, expected): | ||||
| if not isinstance(r, tensor): | |||||
| if not isinstance(r, (tensor, VarNode)): | |||||
| r = tensor(r) | r = tensor(r) | ||||
| compare_fn(r, e) | compare_fn(r, e) | ||||
| @@ -72,9 +89,9 @@ def opr_test( | |||||
| raise ValueError("the input func should be callable") | raise ValueError("the input func should be callable") | ||||
| inp, outp = get_param(cases, 0) | inp, outp = get_param(cases, 0) | ||||
| inp_tensor = [tensor(inpi) for inpi in inp] | |||||
| inp_tensor = [make_tensor(inpi, network) for inpi in inp] | |||||
| if test_trace: | |||||
| if test_trace and not network: | |||||
| copied_inp = inp_tensor.copy() | copied_inp = inp_tensor.copy() | ||||
| for symbolic in [False, True]: | for symbolic in [False, True]: | ||||
| traced_func = trace(symbolic=symbolic)(func) | traced_func = trace(symbolic=symbolic)(func) | ||||
| @@ -10,12 +10,17 @@ import collections | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| from utils import make_tensor | |||||
| import megengine | import megengine | ||||
| import megengine.core.tensor.megbrain_graph as G | |||||
| import megengine.functional as F | |||||
| from megengine.core._imperative_rt.core2 import apply | from megengine.core._imperative_rt.core2 import apply | ||||
| from megengine.core._trace_option import use_symbolic_shape | from megengine.core._trace_option import use_symbolic_shape | ||||
| from megengine.core.ops import builtin | from megengine.core.ops import builtin | ||||
| from megengine.tensor import Tensor | from megengine.tensor import Tensor | ||||
| from megengine.utils.network import Network | |||||
| from megengine.utils.network_node import VarNode | |||||
| def cvt_to_shape_desc(val, inpvar, config=None): | def cvt_to_shape_desc(val, inpvar, config=None): | ||||
| @@ -387,108 +392,130 @@ def test_batched_mesh_indexing(): | |||||
| # high level | # high level | ||||
| def get_value(x): | |||||
| if isinstance(x, VarNode): | |||||
| var = x.var | |||||
| o = G.OutputNode(var) | |||||
| graph = x.graph | |||||
| graph.compile(o.outputs).execute() | |||||
| return o.get_value().numpy() | |||||
| else: | |||||
| return x.numpy() | |||||
| @pytest.mark.parametrize("test_varnode", [True, False]) | |||||
| def test_advance_indexing_high_level(test_varnode): | |||||
| if test_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| def test_advance_indexing_high_level(): | |||||
| x = np.arange(25).reshape(5, 5).astype("int32") | x = np.arange(25).reshape(5, 5).astype("int32") | ||||
| d = np.arange(15).reshape(3, 5).astype("int32") | d = np.arange(15).reshape(3, 5).astype("int32") | ||||
| xx = Tensor(x) | |||||
| xx = make_tensor(x, network) | |||||
| np.testing.assert_equal(x[1, :], xx[1, :].numpy()) | |||||
| np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||||
| np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy()) | |||||
| np.testing.assert_equal(x[1, :], get_value(xx[1, :])) | |||||
| np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||||
| np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :])) | |||||
| np.testing.assert_equal(x[:, :], xx[:, :].numpy()) | |||||
| np.testing.assert_equal(x[1, 1], xx[1, 1].numpy()) | |||||
| np.testing.assert_equal(x[:, :], get_value(xx[:, :])) | |||||
| np.testing.assert_equal(x[1, 1], get_value(xx[1, 1])) | |||||
| yy = xx[(0, 4, 2), :] | yy = xx[(0, 4, 2), :] | ||||
| np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy()) | |||||
| np.testing.assert_equal(x[(0, 4, 2), :], get_value(yy)) | |||||
| x_ = x.copy() | x_ = x.copy() | ||||
| x_[(0, 4, 2), :] = d | x_[(0, 4, 2), :] = d | ||||
| xx_ = Tensor(xx) | |||||
| xx_ = make_tensor(xx, network) | |||||
| xx_[(0, 4, 2), :] = d | xx_[(0, 4, 2), :] = d | ||||
| np.testing.assert_equal(x_, xx_.numpy()) | |||||
| np.testing.assert_equal(x_, get_value(xx_)) | |||||
| x = np.arange(27).reshape(3, 3, 3).astype("int32") | x = np.arange(27).reshape(3, 3, 3).astype("int32") | ||||
| xx = Tensor(x) | |||||
| xx = make_tensor(x, network) | |||||
| np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy()) | |||||
| np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy()) | |||||
| np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy()) | |||||
| np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy()) | |||||
| np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy()) | |||||
| np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||||
| np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy()) | |||||
| np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :])) | |||||
| np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1])) | |||||
| np.testing.assert_equal(x[1, 0:1, :], get_value(xx[1, 0:1, :])) | |||||
| np.testing.assert_equal(x[0:1, 1, 1], get_value(xx[0:1, 1, 1])) | |||||
| np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1])) | |||||
| np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||||
| np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2])) | |||||
| x_ = x.copy() | x_ = x.copy() | ||||
| x_[1, 1, 1] = -1 | x_[1, 1, 1] = -1 | ||||
| xx[1, 1, 1] = -1 | xx[1, 1, 1] = -1 | ||||
| np.testing.assert_equal(x_, xx.numpy()) | |||||
| np.testing.assert_equal(x_, get_value(xx)) | |||||
| x_[:, 1, 1] = -2 | x_[:, 1, 1] = -2 | ||||
| xx[:, 1, 1] = x_[:, 1, 1] | xx[:, 1, 1] = x_[:, 1, 1] | ||||
| np.testing.assert_equal(x_, xx.numpy()) | |||||
| np.testing.assert_equal(x_, get_value(xx)) | |||||
| x_[0:1, :, 1] = -3 | x_[0:1, :, 1] = -3 | ||||
| xx[0:1, :, 1] = x_[0:1, :, 1] | xx[0:1, :, 1] = x_[0:1, :, 1] | ||||
| np.testing.assert_equal(x_, xx.numpy()) | |||||
| np.testing.assert_equal(x_, get_value(xx)) | |||||
| x_[0:1, :, 1] = -4 | x_[0:1, :, 1] = -4 | ||||
| y = Tensor(x_) | |||||
| y = make_tensor(x_, network) | |||||
| xx[0:1, :, 1] = y[0:1, :, 1] | xx[0:1, :, 1] = y[0:1, :, 1] | ||||
| np.testing.assert_equal(y.numpy(), xx.numpy()) | |||||
| np.testing.assert_equal(get_value(y), get_value(xx)) | |||||
| x[:] = 1 | x[:] = 1 | ||||
| xx[:] = 1 | xx[:] = 1 | ||||
| np.testing.assert_equal(x, xx.numpy()) | |||||
| np.testing.assert_equal(x, get_value(xx)) | |||||
| x = np.arange(9).reshape(3, 3).astype("int32") | x = np.arange(9).reshape(3, 3).astype("int32") | ||||
| xx = Tensor(x) | |||||
| xx = make_tensor(x, network) | |||||
| y = np.array([1, 2]) | y = np.array([1, 2]) | ||||
| yy = Tensor(y) | |||||
| np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) | |||||
| np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) | |||||
| np.testing.assert_equal(x[:, y], xx[:, y].numpy()) | |||||
| np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) | |||||
| yy = make_tensor(y, network) | |||||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) | |||||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) | |||||
| np.testing.assert_equal(x[:, y], get_value(xx[:, y])) | |||||
| np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) | |||||
| x_ = x.copy() | x_ = x.copy() | ||||
| x_[:, y[0]] = -1 | x_[:, y[0]] = -1 | ||||
| xx_ = Tensor(x_) | |||||
| xx_ = make_tensor(x_, network) | |||||
| xx[:, yy[0]] = xx_[:, yy[0]] | xx[:, yy[0]] = xx_[:, yy[0]] | ||||
| np.testing.assert_equal(x_, xx.numpy()) | |||||
| np.testing.assert_equal(x_, get_value(xx)) | |||||
| x_[:, y] = -1 | x_[:, y] = -1 | ||||
| xx_ = Tensor(x_) | |||||
| xx_ = make_tensor(x_, network) | |||||
| xx[:, yy] = xx_[:, yy] | xx[:, yy] = xx_[:, yy] | ||||
| np.testing.assert_equal(x_, xx.numpy()) | |||||
| np.testing.assert_equal(x_, get_value(xx)) | |||||
| x = np.arange(9).reshape(3, 3).astype("int32") | x = np.arange(9).reshape(3, 3).astype("int32") | ||||
| xx = Tensor(x) | |||||
| xx = make_tensor(x, network) | |||||
| y = np.array([1]) | y = np.array([1]) | ||||
| yy = Tensor(y) | |||||
| np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) | |||||
| np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) | |||||
| np.testing.assert_equal(x[:, y], xx[:, y].numpy()) | |||||
| yy = make_tensor(y, network) | |||||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) | |||||
| np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) | |||||
| np.testing.assert_equal(x[:, y], get_value(xx[:, y])) | |||||
| np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) | |||||
| np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) | |||||
| x = np.arange(9).reshape(3, 3).astype("int32") | x = np.arange(9).reshape(3, 3).astype("int32") | ||||
| xx = Tensor(x) | |||||
| np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy()) | |||||
| np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy()) | |||||
| def test_advance_indexing_with_bool(): | |||||
| xx = make_tensor(x, network) | |||||
| np.testing.assert_equal(x[[0, 1], 0], get_value(xx[[0, 1], 0])) | |||||
| np.testing.assert_equal(x[0:2, 0], get_value(xx[0:2, 0])) | |||||
| @pytest.mark.parametrize( | |||||
| "test_varnode", [True, False], | |||||
| ) | |||||
| def test_advance_indexing_with_bool(test_varnode): | |||||
| if test_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| a = np.arange(9).reshape(3, 3).astype(np.float32) | a = np.arange(9).reshape(3, 3).astype(np.float32) | ||||
| b = np.array([1, 2, 3]) | b = np.array([1, 2, 3]) | ||||
| c = np.array([1, 2, 3]) | c = np.array([1, 2, 3]) | ||||
| aa = Tensor(a) | |||||
| bb = Tensor(b) | |||||
| cc = Tensor(c) | |||||
| np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy()) | |||||
| aa = make_tensor(a, network) | |||||
| bb = make_tensor(b, network) | |||||
| cc = make_tensor(c, network) | |||||
| np.testing.assert_equal(a[b == 1, c == 2], get_value(aa[bb == 1, cc == 2])) | |||||
| a[b == 1, c == 2] = -1.0 | a[b == 1, c == 2] = -1.0 | ||||
| aa[bb == 1, cc == 2] = -1.0 | aa[bb == 1, cc == 2] = -1.0 | ||||
| np.testing.assert_equal(a, aa.numpy()) | |||||
| np.testing.assert_equal(a, get_value(aa)) | |||||
| a = np.arange(9).reshape(3, 3).astype(np.float32) | a = np.arange(9).reshape(3, 3).astype(np.float32) | ||||
| b = np.array([False, True, True]) | b = np.array([False, True, True]) | ||||
| @@ -11,13 +11,16 @@ import platform | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| from utils import opr_test | |||||
| from utils import make_tensor, opr_test | |||||
| import megengine.functional as F | import megengine.functional as F | ||||
| from megengine import tensor | from megengine import tensor | ||||
| from megengine.core._trace_option import use_symbolic_shape | from megengine.core._trace_option import use_symbolic_shape | ||||
| from megengine.core.tensor import megbrain_graph as G | |||||
| from megengine.core.tensor.utils import astensor1d | from megengine.core.tensor.utils import astensor1d | ||||
| from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
| from megengine.utils.network import Network | |||||
| from megengine.utils.network_node import VarNode | |||||
| def test_eye(): | def test_eye(): | ||||
| @@ -38,7 +41,13 @@ def test_eye(): | |||||
| ) | ) | ||||
| def test_concat(): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_concat(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| def get_data_shape(length: int): | def get_data_shape(length: int): | ||||
| return (length, 2, 3) | return (length, 2, 3) | ||||
| @@ -50,18 +59,30 @@ def test_concat(): | |||||
| return F.concat([data1, data2]) | return F.concat([data1, data2]) | ||||
| cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] | cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] | ||||
| opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y])) | |||||
| opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) | |||||
| def test_concat_device(): | |||||
| data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0") | |||||
| data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1") | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_concat_device(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0") | |||||
| data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1") | |||||
| out = F.concat([data1, data2], device="cpu0") | out = F.concat([data1, data2], device="cpu0") | ||||
| assert str(out.device).split(":")[0] == "cpu0" | assert str(out.device).split(":")[0] == "cpu0" | ||||
| def test_stack(): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_stack(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| data1 = np.random.random((3, 2, 2)).astype("float32") | data1 = np.random.random((3, 2, 2)).astype("float32") | ||||
| data2 = np.random.random((3, 2, 2)).astype("float32") | data2 = np.random.random((3, 2, 2)).astype("float32") | ||||
| data3 = np.random.random((3, 2, 2)).astype("float32") | data3 = np.random.random((3, 2, 2)).astype("float32") | ||||
| @@ -72,12 +93,20 @@ def test_stack(): | |||||
| def run(data1, data2): | def run(data1, data2): | ||||
| return F.stack([data1, data2], axis=ai) | return F.stack([data1, data2], axis=ai) | ||||
| opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai)) | |||||
| opr_test( | |||||
| cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network | |||||
| ) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_split(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| def test_split(): | |||||
| data = np.random.random((2, 3, 4, 5)).astype(np.float32) | data = np.random.random((2, 3, 4, 5)).astype(np.float32) | ||||
| inp = tensor(data) | |||||
| inp = make_tensor(data, network) | |||||
| mge_out0 = F.split(inp, 2, axis=3) | mge_out0 = F.split(inp, 2, axis=3) | ||||
| mge_out1 = F.split(inp, [3], axis=3) | mge_out1 = F.split(inp, [3], axis=3) | ||||
| @@ -106,26 +135,42 @@ def test_split(): | |||||
| assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" | assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" | ||||
| def test_reshape(): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_reshape(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| x = np.arange(6, dtype="float32") | x = np.arange(6, dtype="float32") | ||||
| xx = tensor(x) | |||||
| xx = make_tensor(x, network) | |||||
| y = x.reshape(1, 2, 3) | y = x.reshape(1, 2, 3) | ||||
| for shape in [ | for shape in [ | ||||
| (1, 2, 3), | (1, 2, 3), | ||||
| (1, -1, 3), | (1, -1, 3), | ||||
| (1, tensor(-1), 3), | |||||
| (1, make_tensor(-1, network), 3), | |||||
| np.array([1, -1, 3], dtype="int32"), | np.array([1, -1, 3], dtype="int32"), | ||||
| tensor([1, -1, 3]), | |||||
| make_tensor([1, -1, 3], network), | |||||
| ]: | ]: | ||||
| yy = F.reshape(xx, shape) | yy = F.reshape(xx, shape) | ||||
| np.testing.assert_equal(yy.numpy(), y) | np.testing.assert_equal(yy.numpy(), y) | ||||
| def test_reshape_shape_inference(): | |||||
| x_shape_known = tensor([1, 2, 3, 4], dtype="float32") | |||||
| x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum()) | |||||
| tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_reshape_shape_inference(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| x_shape_known = make_tensor([1, 2, 3, 4], network) | |||||
| x_shape_unknown = F.broadcast_to( | |||||
| make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum() | |||||
| ) | |||||
| tshp_unknown = astensor1d( | |||||
| (make_tensor([2], network), make_tensor([2], network)), x_shape_known | |||||
| ) | |||||
| tshp_known = astensor1d((2, 2), x_shape_known) | tshp_known = astensor1d((2, 2), x_shape_known) | ||||
| tshp_known_unspec = astensor1d((2, -1), x_shape_known) | tshp_known_unspec = astensor1d((2, -1), x_shape_known) | ||||
| @@ -146,12 +191,18 @@ def test_reshape_shape_inference(): | |||||
| {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, | {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, | ||||
| {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, | {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, | ||||
| ] | ] | ||||
| opr_test(cases, func, compare_fn=check_shape, test_trace=True) | |||||
| opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_squeeze(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| def test_squeeze(): | |||||
| x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) | x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) | ||||
| xx = tensor(x) | |||||
| xx = make_tensor(x, network) | |||||
| for axis in [None, 3, -4, (3, -4)]: | for axis in [None, 3, -4, (3, -4)]: | ||||
| y = np.squeeze(x, axis) | y = np.squeeze(x, axis) | ||||
| @@ -159,9 +210,15 @@ def test_squeeze(): | |||||
| np.testing.assert_equal(y, yy.numpy()) | np.testing.assert_equal(y, yy.numpy()) | ||||
| def test_expand_dims(): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_expand_dims(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| x = np.arange(6, dtype="float32").reshape(2, 3) | x = np.arange(6, dtype="float32").reshape(2, 3) | ||||
| xx = tensor(x) | |||||
| xx = make_tensor(x, network) | |||||
| for axis in [2, -3, (3, -4), (1, -4)]: | for axis in [2, -3, (3, -4), (1, -4)]: | ||||
| y = np.expand_dims(x, axis) | y = np.expand_dims(x, axis) | ||||
| @@ -169,11 +226,17 @@ def test_expand_dims(): | |||||
| np.testing.assert_equal(y, yy.numpy()) | np.testing.assert_equal(y, yy.numpy()) | ||||
| def test_elemwise_dtype_promotion(): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_elemwise_dtype_promotion(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| x = np.random.rand(2, 3).astype("float32") | x = np.random.rand(2, 3).astype("float32") | ||||
| y = np.random.rand(1, 3).astype("float16") | y = np.random.rand(1, 3).astype("float16") | ||||
| xx = tensor(x) | |||||
| yy = tensor(y) | |||||
| xx = make_tensor(x, network) | |||||
| yy = make_tensor(y, network) | |||||
| z = xx * yy | z = xx * yy | ||||
| np.testing.assert_equal(z.numpy(), x * y) | np.testing.assert_equal(z.numpy(), x * y) | ||||
| @@ -184,7 +247,13 @@ def test_elemwise_dtype_promotion(): | |||||
| np.testing.assert_equal(z.numpy(), x - y) | np.testing.assert_equal(z.numpy(), x - y) | ||||
| def test_linspace(): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_linspace(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| cases = [ | cases = [ | ||||
| {"input": [1, 9, 9]}, | {"input": [1, 9, 9]}, | ||||
| {"input": [3, 10, 8]}, | {"input": [3, 10, 8]}, | ||||
| @@ -193,6 +262,7 @@ def test_linspace(): | |||||
| cases, | cases, | ||||
| F.linspace, | F.linspace, | ||||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ||||
| network=network, | |||||
| ) | ) | ||||
| cases = [ | cases = [ | ||||
| @@ -203,20 +273,28 @@ def test_linspace(): | |||||
| cases, | cases, | ||||
| F.linspace, | F.linspace, | ||||
| ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | ||||
| network=network, | |||||
| ) | ) | ||||
| cases = [ | cases = [ | ||||
| {"input": [1, tensor(9), 9]}, | |||||
| {"input": [tensor(1), 9, tensor(9)]}, | |||||
| {"input": [1, make_tensor(9, network), 9]}, | |||||
| {"input": [make_tensor(1, network), 9, make_tensor(9, network)]}, | |||||
| ] | ] | ||||
| opr_test( | opr_test( | ||||
| cases, | cases, | ||||
| F.linspace, | F.linspace, | ||||
| ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32), | ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32), | ||||
| network=network, | |||||
| ) | ) | ||||
| def test_arange(): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_arange(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| cases = [ | cases = [ | ||||
| {"input": [1, 9, 1]}, | {"input": [1, 9, 1]}, | ||||
| {"input": [2, 10, 2]}, | {"input": [2, 10, 2]}, | ||||
| @@ -225,6 +303,7 @@ def test_arange(): | |||||
| cases, | cases, | ||||
| F.arange, | F.arange, | ||||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ||||
| network=network, | |||||
| ) | ) | ||||
| cases = [ | cases = [ | ||||
| @@ -235,6 +314,7 @@ def test_arange(): | |||||
| cases, | cases, | ||||
| F.arange, | F.arange, | ||||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ||||
| network=network, | |||||
| ) | ) | ||||
| cases = [ | cases = [ | ||||
| @@ -245,20 +325,33 @@ def test_arange(): | |||||
| cases, | cases, | ||||
| F.arange, | F.arange, | ||||
| ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | ||||
| network=network, | |||||
| ) | ) | ||||
| def test_round(): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_round(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| data1_shape = (15,) | data1_shape = (15,) | ||||
| data2_shape = (25,) | data2_shape = (25,) | ||||
| data1 = np.random.random(data1_shape).astype(np.float32) | data1 = np.random.random(data1_shape).astype(np.float32) | ||||
| data2 = np.random.random(data2_shape).astype(np.float32) | data2 = np.random.random(data2_shape).astype(np.float32) | ||||
| cases = [{"input": data1}, {"input": data2}] | cases = [{"input": data1}, {"input": data2}] | ||||
| opr_test(cases, F.round, ref_fn=np.round) | |||||
| opr_test(cases, F.round, ref_fn=np.round, network=network) | |||||
| def test_flatten(): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_flatten(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| data0_shape = (2, 3, 4, 5) | data0_shape = (2, 3, 4, 5) | ||||
| data1_shape = (4, 5, 6, 7) | data1_shape = (4, 5, 6, 7) | ||||
| data0 = np.random.random(data0_shape).astype(np.float32) | data0 = np.random.random(data0_shape).astype(np.float32) | ||||
| @@ -273,7 +366,7 @@ def test_flatten(): | |||||
| {"input": data0, "output": output0}, | {"input": data0, "output": output0}, | ||||
| {"input": data1, "output": output1}, | {"input": data1, "output": output1}, | ||||
| ] | ] | ||||
| opr_test(cases, F.flatten, compare_fn=compare_fn) | |||||
| opr_test(cases, F.flatten, compare_fn=compare_fn, network=network) | |||||
| output0 = (2, 3 * 4 * 5) | output0 = (2, 3 * 4 * 5) | ||||
| output1 = (4, 5 * 6 * 7) | output1 = (4, 5 * 6 * 7) | ||||
| @@ -281,7 +374,7 @@ def test_flatten(): | |||||
| {"input": data0, "output": output0}, | {"input": data0, "output": output0}, | ||||
| {"input": data1, "output": output1}, | {"input": data1, "output": output1}, | ||||
| ] | ] | ||||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1) | |||||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network) | |||||
| output0 = (2, 3, 4 * 5) | output0 = (2, 3, 4 * 5) | ||||
| output1 = (4, 5, 6 * 7) | output1 = (4, 5, 6 * 7) | ||||
| @@ -289,7 +382,7 @@ def test_flatten(): | |||||
| {"input": data0, "output": output0}, | {"input": data0, "output": output0}, | ||||
| {"input": data1, "output": output1}, | {"input": data1, "output": output1}, | ||||
| ] | ] | ||||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2) | |||||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network) | |||||
| output0 = (2, 3 * 4, 5) | output0 = (2, 3 * 4, 5) | ||||
| output1 = (4, 5 * 6, 7) | output1 = (4, 5 * 6, 7) | ||||
| @@ -297,10 +390,23 @@ def test_flatten(): | |||||
| {"input": data0, "output": output0}, | {"input": data0, "output": output0}, | ||||
| {"input": data1, "output": output1}, | {"input": data1, "output": output1}, | ||||
| ] | ] | ||||
| opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) | |||||
| opr_test( | |||||
| cases, | |||||
| F.flatten, | |||||
| compare_fn=compare_fn, | |||||
| start_axis=1, | |||||
| end_axis=2, | |||||
| network=network, | |||||
| ) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_broadcast(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| def test_broadcast(): | |||||
| input1_shape = (20, 30) | input1_shape = (20, 30) | ||||
| output1_shape = (30, 20, 30) | output1_shape = (30, 20, 30) | ||||
| data1 = np.random.random(input1_shape).astype(np.float32) | data1 = np.random.random(input1_shape).astype(np.float32) | ||||
| @@ -321,7 +427,7 @@ def test_broadcast(): | |||||
| {"input": [data2, output2_shape], "output": output2_shape}, | {"input": [data2, output2_shape], "output": output2_shape}, | ||||
| {"input": [data3, output3_shape], "output": output3_shape}, | {"input": [data3, output3_shape], "output": output3_shape}, | ||||
| ] | ] | ||||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||||
| opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network) | |||||
| x = F.ones((2, 1, 3)) | x = F.ones((2, 1, 3)) | ||||
| with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
| @@ -334,35 +440,41 @@ def test_broadcast(): | |||||
| F.broadcast_to(x, (1, 3)) | F.broadcast_to(x, (1, 3)) | ||||
| def test_utils_astensor1d(): | |||||
| reference = tensor(0) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_utils_astensor1d(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| reference = make_tensor(0, network) | |||||
| # literal | # literal | ||||
| x = [1, 2, 3] | x = [1, 2, 3] | ||||
| for dtype in [None, "float32"]: | for dtype in [None, "float32"]: | ||||
| xx = astensor1d(x, reference, dtype=dtype) | xx = astensor1d(x, reference, dtype=dtype) | ||||
| assert type(xx) is tensor | |||||
| assert isinstance(xx, type(reference)) | |||||
| np.testing.assert_equal(xx.numpy(), x) | np.testing.assert_equal(xx.numpy(), x) | ||||
| # numpy array | # numpy array | ||||
| x = np.asarray([1, 2, 3], dtype="int32") | x = np.asarray([1, 2, 3], dtype="int32") | ||||
| for dtype in [None, "float32"]: | for dtype in [None, "float32"]: | ||||
| xx = astensor1d(x, reference, dtype=dtype) | xx = astensor1d(x, reference, dtype=dtype) | ||||
| assert type(xx) is tensor | |||||
| assert isinstance(xx, type(reference)) | |||||
| np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x) | np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x) | ||||
| # tensor | # tensor | ||||
| x = tensor([1, 2, 3], dtype="int32") | |||||
| x = make_tensor([1, 2, 3], network) | |||||
| for dtype in [None, "float32"]: | for dtype in [None, "float32"]: | ||||
| xx = astensor1d(x, reference, dtype=dtype) | xx = astensor1d(x, reference, dtype=dtype) | ||||
| assert type(xx) is tensor | |||||
| assert isinstance(xx, type(reference)) | |||||
| np.testing.assert_equal(xx.numpy(), x.numpy()) | np.testing.assert_equal(xx.numpy(), x.numpy()) | ||||
| # mixed | # mixed | ||||
| x = [1, tensor(2), 3] | |||||
| x = [1, make_tensor(2, network), 3] | |||||
| for dtype in [None, "float32"]: | for dtype in [None, "float32"]: | ||||
| xx = astensor1d(x, reference, dtype=dtype) | xx = astensor1d(x, reference, dtype=dtype) | ||||
| assert type(xx) is tensor | |||||
| assert isinstance(xx, type(reference)) | |||||
| np.testing.assert_equal(xx.numpy(), [1, 2, 3]) | np.testing.assert_equal(xx.numpy(), [1, 2, 3]) | ||||
| @@ -382,35 +494,60 @@ def test_device(): | |||||
| np.testing.assert_almost_equal(y5.numpy(), y6.numpy()) | np.testing.assert_almost_equal(y5.numpy(), y6.numpy()) | ||||
| def test_identity(): | |||||
| x = tensor(np.random.random((5, 10)).astype(np.float32)) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_identity(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| x = make_tensor(np.random.random((5, 10)).astype(np.float32), network) | |||||
| y = F.copy(x) | y = F.copy(x) | ||||
| np.testing.assert_equal(y.numpy(), x) | np.testing.assert_equal(y.numpy(), x) | ||||
| def copy_test(dst, src): | |||||
| def copy_test(dst, src, network): | |||||
| data = np.random.random((2, 3)).astype(np.float32) | data = np.random.random((2, 3)).astype(np.float32) | ||||
| x = tensor(data, device=src) | |||||
| x = make_tensor(data, device=src, network=network) | |||||
| y = F.copy(x, dst) | y = F.copy(x, dst) | ||||
| assert np.allclose(data, y.numpy()) | assert np.allclose(data, y.numpy()) | ||||
| z = x.to(dst) | |||||
| assert np.allclose(data, z.numpy()) | |||||
| if network is None: | |||||
| z = x.to(dst) | |||||
| assert np.allclose(data, z.numpy()) | |||||
| @pytest.mark.require_ngpu(1) | @pytest.mark.require_ngpu(1) | ||||
| def test_copy_h2d(): | |||||
| copy_test("cpu0", "gpu0") | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_copy_h2d(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| copy_test("cpu0", "gpu0", network=network) | |||||
| @pytest.mark.require_ngpu(1) | @pytest.mark.require_ngpu(1) | ||||
| def test_copy_d2h(): | |||||
| copy_test("gpu0", "cpu0") | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_copy_d2h(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| copy_test("gpu0", "cpu0", network=network) | |||||
| @pytest.mark.require_ngpu(2) | @pytest.mark.require_ngpu(2) | ||||
| def test_copy_d2d(): | |||||
| copy_test("gpu0", "gpu1") | |||||
| copy_test("gpu0:0", "gpu0:1") | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_copy_d2d(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| copy_test("gpu0", "gpu1", network=network) | |||||
| copy_test("gpu0:0", "gpu0:1", network=network) | |||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| @@ -425,7 +562,13 @@ def test_copy_d2d(): | |||||
| ((), 10, None), | ((), 10, None), | ||||
| ], | ], | ||||
| ) | ) | ||||
| def test_repeat(shape, repeats, axis): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_repeat(shape, repeats, axis, is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| def repeat_func(inp): | def repeat_func(inp): | ||||
| return F.repeat(inp=inp, repeats=repeats, axis=axis) | return F.repeat(inp=inp, repeats=repeats, axis=axis) | ||||
| @@ -437,7 +580,10 @@ def test_repeat(shape, repeats, axis): | |||||
| cases = [{"input": np.array(1.23)}] | cases = [{"input": np.array(1.23)}] | ||||
| opr_test( | opr_test( | ||||
| cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis), | |||||
| cases, | |||||
| repeat_func, | |||||
| ref_fn=lambda inp: np.repeat(inp, repeats, axis), | |||||
| network=network, | |||||
| ) | ) | ||||
| @@ -450,14 +596,16 @@ def test_repeat(shape, repeats, axis): | |||||
| ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)), | ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)), | ||||
| ], | ], | ||||
| ) | ) | ||||
| def test_tile(shape, reps): | |||||
| @pytest.mark.parametrize("is_varnode", [True]) | |||||
| def test_tile(shape, reps, is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| def tile_func(inp): | def tile_func(inp): | ||||
| return F.tile(inp=inp, reps=reps) | return F.tile(inp=inp, reps=reps) | ||||
| cases = [ | |||||
| {"input": np.random.randn(*shape).astype("float32")}, | |||||
| ] | |||||
| cases = [{"input": np.random.randn(*shape).astype("float32")}] | |||||
| opr_test( | |||||
| cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), | |||||
| ) | |||||
| opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network) | |||||
| @@ -34,13 +34,11 @@ def test_replace_var(): | |||||
| vara = graph.var_filter.name("a").as_unique() | vara = graph.var_filter.name("a").as_unique() | ||||
| varb = graph.var_filter.name("b").as_unique() | varb = graph.var_filter.name("b").as_unique() | ||||
| out = F.mul(vara.var, varb.var) | |||||
| out = F.mul(vara, varb) | |||||
| out = F.relu(out) | out = F.relu(out) | ||||
| var_list = graph.add_dep_oprs(out) | |||||
| opnode = list(graph.opr_filter.has_input(vara)) | opnode = list(graph.opr_filter.has_input(vara)) | ||||
| repl_dict = {opnode[0].outputs[0]: var_list[0]} | |||||
| repl_dict = {opnode[0].outputs[0]: out} | |||||
| graph.replace_vars(repl_dict) | graph.replace_vars(repl_dict) | ||||
| modified_model = io.BytesIO() | modified_model = io.BytesIO() | ||||
| @@ -72,14 +70,12 @@ def test_replace_opr(): | |||||
| vara = graph.var_filter.name("a").as_unique() | vara = graph.var_filter.name("a").as_unique() | ||||
| varb = graph.var_filter.name("b").as_unique() | varb = graph.var_filter.name("b").as_unique() | ||||
| out1 = F.sub(vara.var, varb.var) | |||||
| out1 = F.sub(vara, varb) | |||||
| out1 = F.relu(out1) | out1 = F.relu(out1) | ||||
| var_list = graph.add_dep_oprs(out1) | |||||
| repl_opr = as_oprnode(var_list) | |||||
| out1 = graph.add_dep_oprs(out1) | |||||
| orig_opr = graph.opr_filter.has_input(vara).as_unique() | orig_opr = graph.opr_filter.has_input(vara).as_unique() | ||||
| repl_dict = {orig_opr: repl_opr} | |||||
| repl_dict = {orig_opr: out1[0].owner} | |||||
| graph.replace_oprs(repl_dict) | graph.replace_oprs(repl_dict) | ||||
| modified_model1 = io.BytesIO() | modified_model1 = io.BytesIO() | ||||
| graph.dump(modified_model1) | graph.dump(modified_model1) | ||||
| @@ -171,8 +167,7 @@ def test_add_input(): | |||||
| inp_c = graph.make_input_node((2,), np.int32, name="c") | inp_c = graph.make_input_node((2,), np.int32, name="c") | ||||
| varo = graph.var_filter.name("o").as_unique() | varo = graph.var_filter.name("o").as_unique() | ||||
| out = F.add(varo.var, inp_c.var) | |||||
| out = graph.add_dep_oprs(out)[0] | |||||
| out = F.add(varo, inp_c) | |||||
| out.name = "o1" | out.name = "o1" | ||||
| graph.remove_output(varo) | graph.remove_output(varo) | ||||
| graph.add_output(out) | graph.add_output(out) | ||||
| @@ -206,12 +201,11 @@ def test_add_output(): | |||||
| var_a = net.var_filter.name("a").as_unique() | var_a = net.var_filter.name("a").as_unique() | ||||
| var_b = net.var_filter.name("b").as_unique() | var_b = net.var_filter.name("b").as_unique() | ||||
| y = F.add(var_a.var, var_b.var) | |||||
| y = F.add(var_a, var_b) | |||||
| y = F.sigmoid(y) | y = F.sigmoid(y) | ||||
| new_vars = net.add_dep_oprs(y)[0] | |||||
| new_vars.name = "o1" | |||||
| net.add_output(new_vars) | |||||
| y.name = "o1" | |||||
| net.add_output(y) | |||||
| modified_model = io.BytesIO() | modified_model = io.BytesIO() | ||||
| net.dump(modified_model) | net.dump(modified_model) | ||||