remove core.tensor, raw_tensor,TensorWrapper
avoid create tensor with zero-stride numpy ndarray
GitOrigin-RevId: 4fe5c4c5ba
tags/v1.2.0
| @@ -9,5 +9,4 @@ | |||
| import os | |||
| import sys | |||
| from .tensor import Tensor | |||
| from .tensor.megbrain_graph import Graph | |||
| @@ -27,8 +27,6 @@ from ..ops.builtin import ( | |||
| from ..ops.special import Const | |||
| from ..tensor.core import apply | |||
| from ..tensor.function import Function | |||
| from ..tensor.tensor import Tensor | |||
| from ..tensor.tensor_wrapper import TensorWrapper | |||
| @functools.singledispatch | |||
| @@ -21,7 +21,6 @@ from ..ops.builtin import Elemwise, OpDef, RemoteSend | |||
| from ..ops.special import Const | |||
| from ..tensor.core import TensorBase, TensorWrapperBase, apply | |||
| from ..tensor.function import Function | |||
| from ..tensor.tensor import Tensor, get_context | |||
| from . import builtin_op_utils | |||
| """ Some notes: | |||
| @@ -65,238 +64,6 @@ def get_tensor(x): | |||
| return get_tensor(x) | |||
| class Grad: | |||
| def __init__(self, name=None): | |||
| if name is None: | |||
| global _grad_count | |||
| self._name = "grad_" + str(_grad_count) | |||
| _grad_count += 1 | |||
| else: | |||
| self._name = name | |||
| assert self._name not in _grad_manager_dict, "grad manager name duplicated" | |||
| _grad_manager_dict[self._name] = self | |||
| # list of all x in partial(y) / partial(x) | |||
| self.xs = [] | |||
| # constains weak reference of all OpNode during forward | |||
| # OpNode contains inputs, outputs and its backward | |||
| # ops forms the computational graph | |||
| self.ops = [] | |||
| # save remote_send output for backward | |||
| self.remote_send_cache = [] | |||
| self._attached_tensors = weakref.WeakSet() | |||
| self._enabled = True | |||
| @property | |||
| def name(self): | |||
| return self._name | |||
| def wrt(self, *args: Tensor, callback=None): | |||
| """ Indicates the loss is a function of the input tensors (usually the net trainable parameters), | |||
| i.e., d (loss) / d (Tensor) != 0 | |||
| callback is used to perform additional operations after gradient is obtained in backward. | |||
| e.g., copy the grad to a particular place | |||
| A VariableNode will be created and saved in the tensor/s _extra_data slot. | |||
| """ | |||
| for x in map(get_tensor, args): | |||
| v = self._new_variable(x, callback=callback) | |||
| assert self not in x._extra_data | |||
| x._extra_data[self] = Tracer(v) | |||
| self.xs.append(v) | |||
| return self | |||
| def _new_variable(self, owner, opnode=None, callback=None): | |||
| self._attached_tensors.add(owner) | |||
| return VariableNode(self, owner, opnode=opnode, callback=callback) | |||
| def _new_opnode(self, inputs, outputs): | |||
| inputs = tuple(inputs) | |||
| for i in inputs: | |||
| assert i is None or isinstance(i, VariableNode) | |||
| o = OpNode() | |||
| o.inputs = inputs | |||
| o.outputs = [] | |||
| tracers = [] | |||
| for i in outputs: | |||
| assert isinstance(i, Tensor) | |||
| v = self._new_variable(i, o) | |||
| o.outputs.append(weakref.ref(v)) | |||
| tracers.append(Tracer(v)) | |||
| self.ops.append(weakref.ref(o)) | |||
| return o, tracers | |||
| def copy(self): | |||
| raise NotImplementedError | |||
| def __enter__(self): | |||
| return self | |||
| def _exit(self): | |||
| """clear all resources""" | |||
| self._enabled = False | |||
| for o in self.ops: | |||
| o = o() | |||
| if o: | |||
| o.clear() | |||
| for i in self._attached_tensors: | |||
| i._extra_data.pop(self, None) | |||
| self.remote_send_cache = [] | |||
| def __exit__(self, *_): | |||
| self._exit() | |||
| def __call__(self, ys, dys): | |||
| """ Defines Grad(). | |||
| :param ys: outputs of forward operators, e.g., the loss tensor | |||
| :type ys: list of Tensor or TensorWrapperBase | |||
| :param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss, | |||
| e.g., one for the loss itself | |||
| :type dys: list of Tensor or TensorWrapperBase | |||
| """ | |||
| assert self._enabled | |||
| self._enabled = False | |||
| def check_wrapper(): | |||
| if isinstance(dys, TensorWrapperBase): | |||
| return type(dys) | |||
| if isinstance(dys, TensorBase): | |||
| return | |||
| assert isinstance(dys, (tuple, list)) | |||
| for i in dys: | |||
| if isinstance(i, TensorWrapperBase): | |||
| return type(i) | |||
| # use Tensor as defualt wrapper | |||
| return mge.Tensor | |||
| Wrapper = check_wrapper() | |||
| def aslist(x): | |||
| if isinstance(x, (Tensor, TensorWrapperBase)): | |||
| x = [x] | |||
| else: | |||
| x = list(x) | |||
| x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x] | |||
| for i in x: | |||
| assert isinstance(i, Tensor) | |||
| return x | |||
| ys = aslist(ys) | |||
| dys = aslist(dys) | |||
| assert len(ys) == len(dys) | |||
| ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()] | |||
| ys = [y for i, y in enumerate(ys) if i in ids] | |||
| dys = [dy for i, dy in enumerate(dys) if i in ids] | |||
| # ys is changed to a list of VariableNode which contains more information | |||
| # such as OpNode, callback, etc. | |||
| ys = [i._extra_data[self].node for i in ys] | |||
| # NOTE: callback is called only if grad is not None | |||
| # the OpNode sequence in backward | |||
| op_seq = [] | |||
| # VariableNode -> (i, j), where i is time stamp in backward, j means jth input | |||
| last_written_to = {} | |||
| def schedule(): | |||
| reached = set(ys) | |||
| # i is the time stamp in backward | |||
| i = 0 | |||
| for o in self.ops[::-1]: | |||
| o = o() | |||
| if o is None: | |||
| continue | |||
| if not o.has_grad_fn(o, reached): | |||
| continue | |||
| op_seq.append(o) | |||
| for j, v in enumerate(o.inputs): | |||
| reached.add(v) | |||
| last_written_to[v] = i, j | |||
| i += 1 | |||
| schedule() | |||
| # VariableNode -> Tensor | |||
| cache = {} | |||
| def initialize(): | |||
| for y, dy in zip(ys, dys): | |||
| cache[y] = dy | |||
| if y not in last_written_to and y.callback: | |||
| y.callback(y.owner(), dy) | |||
| initialize() | |||
| # NOTE: None is used to mark a node has been consumed | |||
| for seqno, opnode in enumerate(op_seq): | |||
| input_nodes = opnode.inputs | |||
| output_nodes = [i() for i in opnode.outputs] | |||
| backward = opnode.backward | |||
| backward_allow_noinput = opnode.backward_allow_noinput | |||
| opnode.clear() | |||
| output_grads = [] | |||
| for i in output_nodes: | |||
| if i is not None: | |||
| if i in cache: | |||
| assert cache[i] is not None | |||
| output_grads.append(cache[i]) | |||
| else: | |||
| output_grads.append(None) | |||
| # read by backward, mark consumed | |||
| cache[i] = None | |||
| else: | |||
| output_grads.append(None) | |||
| if ( | |||
| any([grad is not None for grad in output_grads]) | |||
| or backward_allow_noinput | |||
| ): | |||
| input_grads = backward(*output_grads) | |||
| else: | |||
| input_grads = [None] * len(input_nodes) | |||
| assert len(input_nodes) == len(input_grads) | |||
| for i, (v, g) in enumerate(zip(input_nodes, input_grads)): | |||
| if v is None: | |||
| continue | |||
| if v in cache: | |||
| assert cache[v] | |||
| if g is not None: | |||
| cache[v] = add(cache[v], g) | |||
| elif g is not None: | |||
| cache[v] = g | |||
| if last_written_to[v] == (seqno, i): | |||
| if v.callback: | |||
| v.callback( | |||
| v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | |||
| ) | |||
| if v.opnode is None: | |||
| # won't read by backward, mark consumed | |||
| cache[v] = None | |||
| for v in cache.values(): | |||
| assert v is None | |||
| self._exit() | |||
| def __del__(self): | |||
| self._exit() | |||
| class clearable: | |||
| __cleared = False | |||
| @@ -10,11 +10,6 @@ import warnings | |||
| from typing import Union | |||
| from ..._imperative_rt import OpDef, ops | |||
| from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
| # register OpDef as a "virtual subclass" of OpBase, so any of registered | |||
| # apply(OpBase, ...) rules could work well on OpDef | |||
| OpBase.register(OpDef) | |||
| __all__ = ["OpDef"] | |||
| @@ -6,4 +6,3 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .tensor_wrapper import TensorWrapper as Tensor | |||
| @@ -13,17 +13,9 @@ import sys | |||
| import typing | |||
| from abc import ABC | |||
| from .._imperative_rt.core2 import apply as apply2 | |||
| from .multipledispatch import Dispatcher | |||
| def apply_op(op, *args): | |||
| Wrapper = type(args[0]) | |||
| args = [arg._tensor for arg in args] | |||
| results = apply2(op, *args) | |||
| return tuple(map(Wrapper, results)) | |||
| class OpBase(ABC): | |||
| def __call__(self, *args): | |||
| return apply(self, *args) | |||
| @@ -7,9 +7,6 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..ops.builtin import OpDef | |||
| from .core import TensorBase, TensorWrapperBase, apply | |||
| from .raw_tensor import RawTensor | |||
| from .tensor import Tensor, push_context | |||
| from .tensor_wrapper import TensorWrapper | |||
| class Function: | |||
| @@ -155,13 +152,3 @@ def _(op: Function, *args: TensorWrapperBase): | |||
| t._extra_data[k] = i | |||
| return tuple(map(Wrapper, outputs)) | |||
| @apply.register() | |||
| def _(op: Function, *args: Tensor): | |||
| raise NotImplementedError | |||
| @apply.register() | |||
| def _(op: Function, *args: RawTensor): | |||
| raise NotImplementedError | |||
| @@ -1,117 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import contextlib | |||
| import copy | |||
| from .core import Dispatcher, OpBase, TensorBase, apply | |||
| class Tensor(TensorBase): | |||
| def __init__(self, data: TensorBase): | |||
| self._data = data | |||
| # _extra_data is set up in Grad.wrt | |||
| self._extra_data = {} | |||
| self._user_data = {} | |||
| def __getattr__(self, name): | |||
| if name in self._user_data: | |||
| return self._user_data[name] | |||
| raise AttributeError(name) | |||
| def reset(self, other): | |||
| assert isinstance(other, __class__) | |||
| self.__dict__.clear() | |||
| self._data = other.data | |||
| self._extra_data = other._extra_data.copy() | |||
| self._user_data = other._user_data.copy() | |||
| def copy(self): | |||
| other = object.__new__(type(self)) | |||
| other.reset(self) | |||
| return other | |||
| # tensor interface | |||
| @property | |||
| def shape(self): | |||
| return self._data.shape | |||
| @property | |||
| def dtype(self): | |||
| return self._data.dtype | |||
| @property | |||
| def device(self): | |||
| return self._data.device | |||
| def numpy(self): | |||
| return self._data.numpy() | |||
| def _drop(self): | |||
| self._data._drop() | |||
| def _swap_in(self): | |||
| self._data._swap_in() | |||
| def _swap_out(self): | |||
| self._data._swap_out() | |||
| class ApplyContext: | |||
| __slots__ = ("inputs", "outputs", "key") | |||
| def __init__(self): | |||
| self.inputs = None | |||
| self.outputs = None | |||
| self.key = None | |||
| _context = None | |||
| @contextlib.contextmanager | |||
| def push_context(): | |||
| global _context | |||
| backup = _context | |||
| try: | |||
| _context = ApplyContext() | |||
| yield _context | |||
| finally: | |||
| _context = backup | |||
| def get_context(): | |||
| return _context | |||
| @apply.register() | |||
| def tensor_apply(op: OpBase, *args: Tensor): | |||
| data = tuple(i._data for i in args) | |||
| # type(Tensor._data) is RawTensor | |||
| # dispached to apply.add@RawTensor.py if passed Tensor args | |||
| outputs = apply(op, *data) | |||
| ret = tuple(map(Tensor, outputs)) | |||
| with push_context() as ctx: | |||
| ctx.inputs = args | |||
| ctx.outputs = ret | |||
| for k in set().union(*(i._extra_data for i in args)): | |||
| ctx.key = k | |||
| data = tuple( | |||
| i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args | |||
| ) | |||
| # data are instances of Tracer | |||
| # dispatched to apply.add@grad.py | |||
| outputs = apply(op, *data) | |||
| if outputs is not None: | |||
| assert len(outputs) == len(ret) | |||
| for t, i in zip(ret, outputs): | |||
| t._extra_data[k] = i | |||
| return ret | |||
| @@ -19,7 +19,6 @@ from ..ops import builtin | |||
| from ..ops.builtin import Elemwise, GetVarShape | |||
| from ..ops.special import Const | |||
| from . import utils | |||
| from .core import OpBase, TensorBase, TensorWrapperBase | |||
| from .indexing import getitem as _getitem | |||
| from .indexing import setitem as _setitem | |||
| from .utils import isscalar | |||
| @@ -439,98 +438,3 @@ class ArrayMethodMixin(abc.ABC): | |||
| min = _reduce("MIN") | |||
| max = _reduce("MAX") | |||
| mean = _reduce("MEAN") | |||
| class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): | |||
| def __init__(self, data): | |||
| self.__wrapped__ = data | |||
| def _reset(self, other): | |||
| if not isinstance(other, __class__): | |||
| raise TypeError(type(other)) | |||
| self.__wrapped__ = other.__wrapped__ | |||
| return self | |||
| @property | |||
| def dtype(self): | |||
| return self.__wrapped__.dtype | |||
| @property | |||
| def shape(self): | |||
| shape = self.__wrapped__.shape | |||
| if shape == () or not use_symbolic_shape(): | |||
| return shape | |||
| return apply(GetVarShape(), self)[0] | |||
| @property | |||
| def device(self): | |||
| return self.__wrapped__.device | |||
| def numpy(self): | |||
| return self.__wrapped__.numpy() | |||
| def _drop(self): | |||
| self.__wrapped__._drop() | |||
| def _swap_in(self): | |||
| self.__wrapped__._swap_in() | |||
| def _swap_out(self): | |||
| self.__wrapped__._swap_out() | |||
| class TensorWrapper(ArrayMethodMixin, TensorBase): | |||
| def __init__(self, data, dtype=None, device=None, isscalar=False): | |||
| self._isscalar = isscalar | |||
| if isinstance(data, Tensor): | |||
| self._tensor = data | |||
| else: | |||
| if device is None: | |||
| device = CompNode._get_default_device() | |||
| self._tensor = Tensor(data, dtype, device) | |||
| def _reset(self, other): | |||
| if not isinstance(other, __class__): | |||
| raise TypeError(type(other)) | |||
| self._tensor = other._tensor | |||
| return self | |||
| @property | |||
| def dtype(self): | |||
| return self._tensor.dtype | |||
| @property | |||
| def shape(self): | |||
| if self._isscalar: | |||
| return () | |||
| shape = self._tensor.shape | |||
| if shape == () or not use_symbolic_shape(): | |||
| return shape | |||
| return apply(GetVarShape(), self)[0] | |||
| @property | |||
| def device(self): | |||
| return self._tensor.device | |||
| def numpy(self): | |||
| if self._isscalar: | |||
| return self._tensor.numpy().squeeze() | |||
| return self._tensor.numpy() | |||
| def _drop(self): | |||
| self._tensor._drop() | |||
| def _swap_in(self): | |||
| self._tensor._swap_in() | |||
| def _swap_out(self): | |||
| self._tensor._swap_out() | |||
| def __repr__(self): | |||
| piece = "Tensor(" | |||
| with np.printoptions(precision=4, suppress=True): | |||
| piece += "{}".format(str(self.numpy())) | |||
| if self.dtype != np.float32: | |||
| piece += ", dtype={}".format(np.dtype(self.dtype).name) | |||
| piece += ", device={}".format(self.device) + ")" | |||
| return piece | |||
| @@ -18,9 +18,8 @@ from ..core.autodiff.grad import ( | |||
| tracer_apply, | |||
| ) | |||
| from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
| from ..core.tensor.tensor import Tensor, tensor_apply | |||
| from ..device import get_default_device | |||
| from ..tensor import tensor | |||
| from ..tensor import Tensor | |||
| from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank | |||
| __all__ = [ | |||
| @@ -16,7 +16,6 @@ from ..core._imperative_rt.core2 import apply | |||
| from ..core.ops import builtin | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import utils | |||
| from ..core.tensor.core import TensorBase, TensorWrapperBase | |||
| from ..tensor import Tensor | |||
| from .elemwise import clip, exp, log, log1p | |||
| from .tensor import reshape, squeeze | |||
| @@ -703,7 +702,7 @@ def topk( | |||
| mode = "VALUE_IDX_SORTED" | |||
| op = builtin.TopK(mode=mode) | |||
| if not isinstance(k, (TensorBase, TensorWrapperBase)): | |||
| if not isinstance(k, Tensor): | |||
| (k,) = Const(k, dtype="int32", device=inp.device)(inp) | |||
| if len(inp.shape) == 1: | |||
| @@ -14,7 +14,7 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union | |||
| import numpy as np | |||
| from ..core._imperative_rt import CompNode | |||
| from ..core._imperative_rt.core2 import Tensor, apply | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core._wrap import device as as_device | |||
| from ..core.ops import builtin | |||
| from ..core.ops.special import Const | |||
| @@ -19,6 +19,7 @@ import weakref | |||
| import numpy as np | |||
| from ..core._imperative_rt import GraphProfiler | |||
| from ..core._imperative_rt.core2 import Tensor | |||
| from ..core._imperative_rt.ops import ( | |||
| CollectiveComm, | |||
| GaussianRNG, | |||
| @@ -32,7 +33,6 @@ from ..core.ops.special import Const | |||
| from ..core.tensor import megbrain_graph as G | |||
| from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
| from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor | |||
| from ..core.tensor.tensor import Tensor | |||
| from .sublinear_memory_config import SublinearMemoryConfig | |||
| @@ -10,7 +10,6 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from ..core.tensor.tensor import Tensor | |||
| from ..tensor import Parameter, tensor | |||
| from .optimizer import Optimizer | |||
| @@ -10,7 +10,6 @@ from typing import Iterable, Union | |||
| import numpy as np | |||
| from ..core.tensor.tensor import Tensor | |||
| from ..tensor import Parameter, tensor | |||
| from .optimizer import Optimizer | |||
| @@ -8,7 +8,6 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable, Tuple, Union | |||
| from ..core.tensor.tensor import Tensor | |||
| from ..tensor import Parameter, tensor | |||
| from .optimizer import Optimizer | |||
| @@ -8,7 +8,6 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable, Union | |||
| from ..core.tensor.tensor import Tensor | |||
| from ..tensor import Parameter, tensor | |||
| from .optimizer import Optimizer | |||
| @@ -16,8 +16,8 @@ from .core._imperative_rt import CompNode | |||
| from .core._imperative_rt.core2 import Tensor as _Tensor | |||
| from .core._imperative_rt.core2 import apply | |||
| from .core._trace_option import use_symbolic_shape | |||
| from .core._wrap import device as as_device | |||
| from .core.ops.builtin import Copy, GetVarShape | |||
| from .core.tensor.raw_tensor import as_device | |||
| from .core.tensor.tensor_wrapper import ArrayMethodMixin | |||
| from .device import _valid_device, get_default_device | |||
| from .utils.deprecation import deprecated | |||
| @@ -43,6 +43,10 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
| if isinstance(data, _Tensor): | |||
| obj = _Tensor.__new__(cls, data) | |||
| else: | |||
| if isinstance(data, np.ndarray): | |||
| if 0 in data.strides: | |||
| data = data.squeeze().reshape(data.shape) | |||
| obj = _Tensor.__new__(cls, data, dtype, cn) | |||
| return obj | |||
| @@ -13,7 +13,7 @@ import numpy | |||
| from ..core import _imperative_rt | |||
| from ..core._imperative_rt import OperatorNode, VarNode | |||
| from ..core.tensor import megbrain_graph as G | |||
| from ..core.tensor.raw_tensor import as_raw_tensor | |||
| from ..tensor import Tensor | |||
| __all__ = [ | |||
| "get_dep_vars", | |||
| @@ -309,7 +309,7 @@ def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.n | |||
| cg = new_out_list[0].graph | |||
| func = cg.compile(new_out_list) | |||
| for node, value in zip(inp_node_list, inp_data_list): | |||
| node.set_value(as_raw_tensor(value)._dev_tensor()) | |||
| node.set_value(Tensor(value)._dev_tensor()) | |||
| func.execute() | |||
| out_data_list = [o.get_value().numpy() for o in out_node_list] | |||
| return out_data_list | |||
| @@ -13,7 +13,7 @@ import megengine.functional as F | |||
| from megengine import Parameter, optimizer | |||
| from megengine.jit import trace | |||
| from megengine.module import Linear, Module | |||
| from megengine.tensor import tensor | |||
| from megengine.tensor import Tensor | |||
| class MLP(Module): | |||
| @@ -54,7 +54,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| for group in opt.param_groups: | |||
| group["lr"] += 0.01 | |||
| check_func.lr += 0.01 | |||
| data = tensor(np.random.random(data_shape).astype(np.float32)) | |||
| data = Tensor(np.random.random(data_shape).astype(np.float32)) | |||
| opt.clear_grad() | |||
| with gm: | |||
| @@ -98,7 +98,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
| ori_params[param] = np.copy(param.numpy()) | |||
| train_func( | |||
| tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm | |||
| Tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm | |||
| ) | |||
| step += 1 | |||
| check_func(ori_params, net.parameters(), step) | |||
| @@ -11,7 +11,7 @@ import pickle | |||
| import numpy as np | |||
| from megengine.core.tensor.dtype import bfloat16 | |||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
| from megengine.tensor import Tensor | |||
| def test_define(): | |||
| @@ -42,14 +42,14 @@ def test_cast(): | |||
| def test_shared_nd(): | |||
| data = np.array([-3.4, 1.394683, 2.323497, -7.439948, -5.2397], dtype=bfloat16) | |||
| snd = as_raw_tensor(data, dtype=bfloat16, device="xpux") | |||
| snd = Tensor(data, dtype=bfloat16, device="xpux") | |||
| assert snd.numpy().dtype == bfloat16 | |||
| np.testing.assert_allclose( | |||
| snd.numpy(), [-3.40625, 1.398438, 2.328125, -7.4375, -5.25], atol=1e-6 | |||
| ) | |||
| data = np.array([-9.34964, -8.342, 9.4385, 0.18746, 1.48], dtype=bfloat16) | |||
| snd = as_raw_tensor(data, dtype=bfloat16, device="xpux") | |||
| snd = Tensor(data, dtype=bfloat16, device="xpux") | |||
| np.testing.assert_allclose( | |||
| snd.numpy(), [-9.375, -8.3125, 9.4375, 0.1875, 1.476562], atol=1e-6 | |||
| ) | |||
| @@ -12,7 +12,7 @@ import numpy as np | |||
| import pytest | |||
| from megengine.core.tensor.dtype import intb1, intb2, intb4 | |||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
| from megengine.tensor import Tensor | |||
| def bit_define_test(bit, low_bit_type): | |||
| @@ -78,11 +78,11 @@ def _shared_nd_test(bit, low_bit_type): | |||
| min_value = 1 - (1 << bit) | |||
| data = np.arange(min_value, max_value + 2, 2, dtype=low_bit_type) | |||
| snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux") | |||
| snd = Tensor(data, dtype=low_bit_type, device="xpux") | |||
| np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 2)) | |||
| data = np.arange(min_value, max_value + 2, 4, dtype=low_bit_type) | |||
| snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux") | |||
| snd = Tensor(data, dtype=low_bit_type, device="xpux") | |||
| np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 4)) | |||
| @@ -32,8 +32,8 @@ from megengine.core.tensor.dtype import ( | |||
| quint4, | |||
| quint8, | |||
| ) | |||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| from megengine.tensor import Tensor | |||
| def test_dtype_quint8(): | |||
| @@ -71,7 +71,7 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None): | |||
| temp_rst = calc_func(inp_node.outputs[0]) | |||
| oup_node = G.OutputNode(temp_rst) | |||
| func = graph.compile(oup_node.outputs[0]) | |||
| inp_node.set_value(as_raw_tensor(inp, dtype=dtype, device=device)._dev_tensor()) | |||
| inp_node.set_value(Tensor(inp, dtype=dtype, device=device)._dev_tensor()) | |||
| func.execute() | |||
| return oup_node.get_value().numpy() | |||
| @@ -9,15 +9,15 @@ | |||
| import numpy as np | |||
| import pytest | |||
| import megengine.core.tensor.raw_tensor | |||
| from megengine.core.tensor.core import apply | |||
| import megengine | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.tensor import Tensor | |||
| def elemwise(*args, mode): | |||
| from megengine.core._imperative_rt.imperative import apply_op | |||
| from megengine.core.ops.builtin import Elemwise | |||
| return apply_op(Elemwise(mode), args) | |||
| return apply(Elemwise(mode), *args) | |||
| def test_basic_interface(): | |||
| @@ -44,11 +44,11 @@ def test_simple_arith(): | |||
| from megengine.core.ops.builtin import Elemwise | |||
| x = np.random.rand(10).astype("float32") | |||
| xx = megengine.core._imperative_rt.put(x) | |||
| xx = Tensor(x) | |||
| (yy,) = elemwise(xx, xx, mode=Elemwise.Mode.MUL) | |||
| np.testing.assert_allclose(x * x, megengine.core._imperative_rt.get_value(yy)) | |||
| megengine.core._imperative_rt.delete(xx) | |||
| megengine.core._imperative_rt.delete(yy) | |||
| np.testing.assert_allclose(x * x, yy.numpy()) | |||
| del xx | |||
| del yy | |||
| def test_tensor_on_device(): | |||
| @@ -62,10 +62,9 @@ def test_tensor_on_device(): | |||
| def test_raw_tensor(): | |||
| from megengine.core.ops.builtin import Elemwise | |||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
| x = np.random.rand(10).astype("float32") | |||
| xx = as_raw_tensor(x) | |||
| xx = Tensor(x) | |||
| (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) | |||
| np.testing.assert_allclose(x * x, yy.numpy()) | |||
| (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) | |||
| @@ -12,10 +12,10 @@ import numpy as np | |||
| import pytest | |||
| import megengine | |||
| import megengine.tensor as Tensor | |||
| from megengine.core._imperative_rt.core2 import apply | |||
| from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.ops import builtin | |||
| from megengine.tensor import Tensor | |||
| def cvt_to_shape_desc(val, inpvar, config=None): | |||
| @@ -8,8 +8,6 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import pytest | |||
| from megengine.core import Tensor | |||
| # from megengine.core.interpreter.hints import function | |||
| @@ -11,8 +11,8 @@ from concurrent.futures import Future | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| import megengine.tensor as Tensor | |||
| from megengine.core.tensor import megbrain_graph as mgb_graph | |||
| from megengine.tensor import Tensor | |||
| def test_io(): | |||
| @@ -9,12 +9,12 @@ | |||
| import numpy as np | |||
| import megengine.functional as F | |||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
| from megengine.tensor import Tensor | |||
| def test_as_raw_tensor(): | |||
| x = np.arange(6, dtype="float32").reshape(2, 3) | |||
| xx = as_raw_tensor(x, device="xpux") | |||
| xx = Tensor(x, device="xpux") | |||
| yy = F.add(xx, 1).numpy() | |||
| assert xx.dtype == np.float32 | |||
| assert xx.device == "xpux" | |||
| @@ -23,7 +23,7 @@ def test_as_raw_tensor(): | |||
| def test_as_raw_tensor_from_int64(): | |||
| x = np.arange(6, dtype="int64").reshape(2, 3) | |||
| xx = as_raw_tensor(x, dtype="float32", device="xpux") | |||
| xx = Tensor(x, dtype="float32", device="xpux") | |||
| yy = F.add(xx, 1).numpy() | |||
| assert xx.dtype == np.float32 | |||
| assert xx.device == "xpux" | |||