| @@ -1,40 +0,0 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| from .._imperative_rt import make_const | |||||
| from .._imperative_rt.core2 import SymbolVar, Tensor | |||||
| class Const: | |||||
| def __init__(self, value=None, *, dtype=None, device=None): | |||||
| self.value = np.asarray(value, dtype=dtype) | |||||
| self.dtype = dtype | |||||
| self.device = device | |||||
| def __call__(self, *reference): | |||||
| from ...tensor import Tensor | |||||
| device = self.device | |||||
| if len(reference) != 0: | |||||
| reference = reference[0] | |||||
| assert isinstance( | |||||
| reference, (SymbolVar, Tensor) | |||||
| ), "Reference should be Tensor or VarNode" | |||||
| if device is None: | |||||
| device = reference.device | |||||
| if isinstance(reference, SymbolVar): | |||||
| cls = type(reference) | |||||
| rst = cls(make_const(reference.graph, self.value, device, self.dtype)) | |||||
| return (rst,) | |||||
| return (Tensor(self.value, self.dtype, self.device, True),) | |||||
| @@ -14,6 +14,7 @@ import numpy as np | |||||
| from .._imperative_rt import make_const | from .._imperative_rt import make_const | ||||
| from .._imperative_rt.core2 import ( | from .._imperative_rt.core2 import ( | ||||
| Const, | |||||
| SymbolVar, | SymbolVar, | ||||
| Tensor, | Tensor, | ||||
| _get_convert_inputs, | _get_convert_inputs, | ||||
| @@ -28,7 +29,6 @@ from .._imperative_rt.ops import jit_supported | |||||
| from .._wrap import as_device | from .._wrap import as_device | ||||
| from ..autodiff.grad import Function | from ..autodiff.grad import Function | ||||
| from ..ops import builtin | from ..ops import builtin | ||||
| from ..ops.special import Const | |||||
| from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype | from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype | ||||
| from .dtype import is_dtype_equal, is_quantize | from .dtype import is_dtype_equal, is_quantize | ||||
| @@ -67,7 +67,7 @@ def convert_single_value(v, *, dtype=None, device=None): | |||||
| if not is_quantize(v.dtype): | if not is_quantize(v.dtype): | ||||
| v = astype(v, dtype) | v = astype(v, dtype) | ||||
| else: | else: | ||||
| (v,) = Const(v, dtype=dtype, device=device)() | |||||
| v = Const(v, dtype, device, None) | |||||
| return v | return v | ||||
| @@ -155,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
| if ndim != 0 and ndim != 1: | if ndim != 0 and ndim != 1: | ||||
| raise ValueError("ndim != 1 or 0, get : %d" % ndim) | raise ValueError("ndim != 1 or 0, get : %d" % ndim) | ||||
| if not isinstance(x, (Tensor, SymbolVar)): | if not isinstance(x, (Tensor, SymbolVar)): | ||||
| (x,) = Const(x, dtype=dtype, device=device)(*reference) | |||||
| x = Const(x, dtype, device, reference) | |||||
| return x | return x | ||||
| if not isinstance(x, collections.abc.Sequence): | if not isinstance(x, collections.abc.Sequence): | ||||
| @@ -166,7 +166,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
| if dtype is not None: | if dtype is not None: | ||||
| x = astype(x, dtype) | x = astype(x, dtype) | ||||
| return x | return x | ||||
| (x,) = Const(x, dtype=dtype, device=device)(*reference) | |||||
| x = Const(x, dtype, device, reference) | |||||
| return x | return x | ||||
| @@ -337,7 +337,7 @@ def interpret_subgraph(func, dtype, device): | |||||
| return results | return results | ||||
| def apply_const(value, dtype=dtype, device=device): | def apply_const(value, dtype=dtype, device=device): | ||||
| return Const(value, dtype=dtype, device=device)()[0] | |||||
| return Const(value, dtype, device, None) | |||||
| outputs, outputs_has_grad = func(args, apply_expr, apply_const) | outputs, outputs_has_grad = func(args, apply_expr, apply_const) | ||||
| outputs = [ | outputs = [ | ||||
| @@ -10,10 +10,9 @@ import collections | |||||
| import math | import math | ||||
| from typing import Iterable, Optional, Sequence, Tuple, Union | from typing import Iterable, Optional, Sequence, Tuple, Union | ||||
| from ..core._imperative_rt.core2 import apply, dtype_promotion | |||||
| from ..core._imperative_rt.core2 import Const, apply, dtype_promotion | |||||
| from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.special import Const | |||||
| from ..core.tensor.array_method import _matmul | from ..core.tensor.array_method import _matmul | ||||
| from ..core.tensor.utils import _normalize_axis | from ..core.tensor.utils import _normalize_axis | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| @@ -729,7 +728,7 @@ def topk( | |||||
| op = builtin.TopK(mode=mode) | op = builtin.TopK(mode=mode) | ||||
| if not isinstance(k, Tensor): | if not isinstance(k, Tensor): | ||||
| (k,) = Const(k, dtype="int32", device=inp.device)() | |||||
| k = Const(k, "int32", inp.device, None) | |||||
| if len(inp.shape) == 1: | if len(inp.shape) == 1: | ||||
| if kth_only: | if kth_only: | ||||
| @@ -11,7 +11,7 @@ from functools import lru_cache | |||||
| from typing import NamedTuple, Optional, Sequence, Tuple, Union | from typing import NamedTuple, Optional, Sequence, Tuple, Union | ||||
| from ..core import _config | from ..core import _config | ||||
| from ..core._imperative_rt.core2 import apply, dtype_promotion | |||||
| from ..core._imperative_rt.core2 import Const, apply, dtype_promotion | |||||
| from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | ||||
| from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| @@ -26,7 +26,6 @@ from ..core.ops.builtin import ( | |||||
| Reshape, | Reshape, | ||||
| TypeCvt, | TypeCvt, | ||||
| ) | ) | ||||
| from ..core.ops.special import Const | |||||
| from ..core.tensor import amp, megbrain_graph | from ..core.tensor import amp, megbrain_graph | ||||
| from ..core.tensor.array_method import _elwise_apply | from ..core.tensor.array_method import _elwise_apply | ||||
| from ..core.tensor.utils import ( | from ..core.tensor.utils import ( | ||||
| @@ -1317,7 +1316,7 @@ def batch_norm( | |||||
| raise ValueError("Invalid param_dim {}".format(param_dim)) | raise ValueError("Invalid param_dim {}".format(param_dim)) | ||||
| if x is None: | if x is None: | ||||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)() | |||||
| x = Const(value, inp.dtype, inp.device, None) | |||||
| shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) | ||||
| (result,) = apply(builtin.Broadcast(), x, shape) | (result,) = apply(builtin.Broadcast(), x, shape) | ||||
| return result | return result | ||||
| @@ -1541,7 +1540,7 @@ def sync_batch_norm( | |||||
| def _make_full_if_none(x, value): | def _make_full_if_none(x, value): | ||||
| if x is None: | if x is None: | ||||
| (x,) = Const(value, dtype=inp.dtype, device=_device)() | |||||
| x = Const(value, inp.dtype, _device, None) | |||||
| (result,) = apply(builtin.Broadcast(), x, reduce_shape) | (result,) = apply(builtin.Broadcast(), x, reduce_shape) | ||||
| return result | return result | ||||
| elif x.ndim == 1: | elif x.ndim == 1: | ||||
| @@ -13,6 +13,7 @@ import numpy as np | |||||
| from ..core._imperative_rt import CompNode | from ..core._imperative_rt import CompNode | ||||
| from ..core._imperative_rt.core2 import ( | from ..core._imperative_rt.core2 import ( | ||||
| Const, | |||||
| SymbolVar, | SymbolVar, | ||||
| apply, | apply, | ||||
| broadcast_cpp, | broadcast_cpp, | ||||
| @@ -24,7 +25,6 @@ from ..core._imperative_rt.core2 import ( | |||||
| from ..core._wrap import as_device | from ..core._wrap import as_device | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.builtin import Copy, Identity | from ..core.ops.builtin import Copy, Identity | ||||
| from ..core.ops.special import Const | |||||
| from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn | from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn | ||||
| from ..device import get_default_device | from ..device import get_default_device | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| @@ -177,7 +177,7 @@ def full( | |||||
| shape = (shape,) | shape = (shape,) | ||||
| if device is None: | if device is None: | ||||
| device = get_default_device() | device = get_default_device() | ||||
| (x,) = Const(value, dtype=dtype, device=device)() | |||||
| x = Const(value, dtype, device, None) | |||||
| if type(shape) in (list, tuple) and len(shape) == 0: | if type(shape) in (list, tuple) and len(shape) == 0: | ||||
| return x | return x | ||||
| return broadcast_to(x, shape) | return broadcast_to(x, shape) | ||||
| @@ -325,7 +325,7 @@ def full_like( | |||||
| [2 2 2]] | [2 2 2]] | ||||
| """ | """ | ||||
| (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||||
| x = Const(value, inp.dtype, inp.device, inp) | |||||
| if inp.ndim == 0: | if inp.ndim == 0: | ||||
| return x | return x | ||||
| return broadcast_to(x, inp.shape) | return broadcast_to(x, inp.shape) | ||||
| @@ -1,4 +1,4 @@ | |||||
| from ..core.ops.special import Const | |||||
| from ..core._imperative_rt.core2 import Const | |||||
| from ..jit.tracing import is_tracing | from ..jit.tracing import is_tracing | ||||
| small_tensor_cache = {} | small_tensor_cache = {} | ||||
| @@ -7,11 +7,11 @@ small_tensor_cache = {} | |||||
| def _get_scalar_tensor_with_value(value, dtype=None, device=None): | def _get_scalar_tensor_with_value(value, dtype=None, device=None): | ||||
| global small_tensor_cache | global small_tensor_cache | ||||
| if is_tracing(): | if is_tracing(): | ||||
| (ret,) = Const(value, dtype=dtype, device=device)() | |||||
| ret = Const(value, dtype, device, None) | |||||
| else: | else: | ||||
| cache_key = (value, dtype, device) | cache_key = (value, dtype, device) | ||||
| if cache_key not in small_tensor_cache: | if cache_key not in small_tensor_cache: | ||||
| (ret,) = Const(value, dtype=dtype, device=device)() | |||||
| ret = Const(value, dtype, device, None) | |||||
| small_tensor_cache[cache_key] = ret | small_tensor_cache[cache_key] = ret | ||||
| else: | else: | ||||
| ret = small_tensor_cache[cache_key] | ret = small_tensor_cache[cache_key] | ||||
| @@ -16,6 +16,7 @@ from importlib import import_module | |||||
| from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union | from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union | ||||
| from ..core._imperative_rt import OpDef | from ..core._imperative_rt import OpDef | ||||
| from ..core._imperative_rt.core2 import Const | |||||
| from ..core._imperative_rt.core2 import Tensor as RawTensor | from ..core._imperative_rt.core2 import Tensor as RawTensor | ||||
| from ..core._imperative_rt.core2 import ( | from ..core._imperative_rt.core2 import ( | ||||
| apply, | apply, | ||||
| @@ -25,7 +26,6 @@ from ..core._imperative_rt.core2 import ( | |||||
| unset_module_tracing, | unset_module_tracing, | ||||
| ) | ) | ||||
| from ..core.ops.builtin import FakeQuant | from ..core.ops.builtin import FakeQuant | ||||
| from ..core.ops.special import Const | |||||
| from ..module import Module | from ..module import Module | ||||
| from ..tensor import Parameter, Tensor | from ..tensor import Parameter, Tensor | ||||
| from ..version import __version__ | from ..version import __version__ | ||||
| @@ -764,7 +764,7 @@ class Constant(Expr): | |||||
| def interpret(self, *inputs): | def interpret(self, *inputs): | ||||
| if isinstance(self.value, RawTensor): | if isinstance(self.value, RawTensor): | ||||
| return Const(self.value.numpy())() | |||||
| return (Const(self.value.numpy(), None, None, None),) | |||||
| return (self.value,) | return (self.value,) | ||||
| def __repr__(self): | def __repr__(self): | ||||
| @@ -639,6 +639,7 @@ WRAP_FUNC_PY35(squeeze_cpp); | |||||
| WRAP_FUNC_PY35(transpose_cpp); | WRAP_FUNC_PY35(transpose_cpp); | ||||
| WRAP_FUNC_PY35(broadcast_cpp); | WRAP_FUNC_PY35(broadcast_cpp); | ||||
| WRAP_FUNC_PY35(reshape_cpp); | WRAP_FUNC_PY35(reshape_cpp); | ||||
| WRAP_FUNC_PY35(Const); | |||||
| #undef WRAP_FUNC_PY35 | #undef WRAP_FUNC_PY35 | ||||
| #define MGE_PY_INTERFACE(NAME, FUNC) \ | #define MGE_PY_INTERFACE(NAME, FUNC) \ | ||||
| { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | ||||
| @@ -777,6 +778,7 @@ void init_tensor(py::module m) { | |||||
| MGE_PY_INTERFACE(transpose_cpp, transpose_cpp), | MGE_PY_INTERFACE(transpose_cpp, transpose_cpp), | ||||
| MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp), | MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp), | ||||
| MGE_PY_INTERFACE(reshape_cpp, reshape_cpp), | MGE_PY_INTERFACE(reshape_cpp, reshape_cpp), | ||||
| MGE_PY_INTERFACE(Const, Const), | |||||
| {nullptr, nullptr, 0, nullptr}}; | {nullptr, nullptr, 0, nullptr}}; | ||||
| for (auto&& def : method_defs) { | for (auto&& def : method_defs) { | ||||
| if (def.ml_meth != nullptr) { | if (def.ml_meth != nullptr) { | ||||
| @@ -94,7 +94,7 @@ bool is_bool_dtype(PyObject* args) { | |||||
| } | } | ||||
| py::object _Const( | py::object _Const( | ||||
| py::handle value, py::handle dtype, py::handle device, py::handle ref) { | |||||
| py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) { | |||||
| py::object val = py::reinterpret_borrow<py::object>(value); | py::object val = py::reinterpret_borrow<py::object>(value); | ||||
| if (PyArray_Check(value.ptr())) { | if (PyArray_Check(value.ptr())) { | ||||
| py::tuple strides = | py::tuple strides = | ||||
| @@ -107,21 +107,56 @@ py::object _Const( | |||||
| } | } | ||||
| if (need_squeeze) { | if (need_squeeze) { | ||||
| val = py::reinterpret_borrow<py::array>(value); | val = py::reinterpret_borrow<py::array>(value); | ||||
| py::object orig_shp = val.attr("shape"); | |||||
| val = val.attr("squeeze")(); | val = val.attr("squeeze")(); | ||||
| val = val.attr("reshape")(val.attr("shape")); | |||||
| val = val.attr("reshape")(orig_shp); | |||||
| } | } | ||||
| } | } | ||||
| py::object ref; | |||||
| if (py::isinstance<py::tuple>(ref_hdl)) { | |||||
| py::tuple tup = py::reinterpret_borrow<py::tuple>(ref_hdl); | |||||
| if (tup.size()) { | |||||
| ref = tup[0]; | |||||
| } else { | |||||
| ref = py::none(); | |||||
| } | |||||
| } else { | |||||
| ref = py::reinterpret_borrow<py::object>(ref_hdl); | |||||
| } | |||||
| if (py::isinstance<PySymbolVar>(ref)) { | if (py::isinstance<PySymbolVar>(ref)) { | ||||
| auto ref_var = ref.cast<PySymbolVar*>(); | auto ref_var = ref.cast<PySymbolVar*>(); | ||||
| auto* graph = ref_var->m_node->owner_graph(); | auto* graph = ref_var->m_node->owner_graph(); | ||||
| auto cn = device.cast<CompNode>(); | |||||
| CompNode cn; | |||||
| if (device.ptr() == Py_None) { | |||||
| cn = ref_var->m_node->comp_node(); | |||||
| } else { | |||||
| cn = device.cast<CompNode>(); | |||||
| } | |||||
| OperatorNodeConfig config(cn); | OperatorNodeConfig config(cn); | ||||
| auto hv = npy::np2tensor( | auto hv = npy::np2tensor( | ||||
| val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>()); | val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>()); | ||||
| auto typeobj = ref.get_type(); | auto typeobj = ref.get_type(); | ||||
| return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); | return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); | ||||
| } | } | ||||
| py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none()); | |||||
| py::object device_obj; | |||||
| if (device.ptr() == Py_None) { | |||||
| device_obj = py::cast(CompNode::load(get_default_device())); | |||||
| } else if (py::isinstance<py::str>(device)) { | |||||
| py::object dmap = | |||||
| getattr(py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type), | |||||
| "dmap_callback"); | |||||
| if (dmap.ptr() != Py_None) { | |||||
| device_obj = dmap(device); | |||||
| py::print(device_obj); | |||||
| } else { | |||||
| device_obj = py::cast(CompNode::load(device.cast<std::string>())); | |||||
| } | |||||
| } else if (py::isinstance<CompNode>(device)) { | |||||
| device_obj = py::reinterpret_borrow<py::object>(device); | |||||
| } else { | |||||
| device_obj = getattr(device, "_cn"); | |||||
| } | |||||
| py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none()); | |||||
| return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr); | return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr); | ||||
| } | } | ||||
| @@ -1107,4 +1142,14 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
| } | } | ||||
| PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| try { | |||||
| return _Const(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]), | |||||
| py::handle(args[3])) | |||||
| .release() | |||||
| .ptr(); | |||||
| } | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
| } | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||
| @@ -20,4 +20,6 @@ PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
| PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
| PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs); | |||||
| } // namespace mgb::imperative::python | } // namespace mgb::imperative::python | ||||