* remove dispatcher/interpreter python wrapper
* rename tensor_wrapper to array_method
GitOrigin-RevId: b8a402c2be
tags/v1.2.0
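Note on the rename: it is a pure path change. `megengine.core.tensor.tensor_wrapper` becomes `megengine.core.tensor.array_method`, and the symbols touched in the hunks below keep their names. A minimal before/after sketch, assuming user code reached into these private modules directly:

    # before (v1.1 and earlier)
    from megengine.core.tensor.tensor_wrapper import ArrayMethodMixin, _broadcast, _remove_axis
    # after this change
    from megengine.core.tensor.array_method import ArrayMethodMixin, _broadcast, _remove_axis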
@@ -18,7 +18,7 @@ from ..core._imperative_rt.core2 import apply
 from ..core._wrap import device as as_device
 from ..core.ops import builtin
 from ..core.ops.special import Const
-from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
+from ..core.tensor.array_method import _broadcast, _remove_axis
 from ..core.tensor.utils import (
     astensor1d,
     convert_inputs,
@@ -18,7 +18,7 @@ import weakref
 import numpy as np
 
-from ..core._imperative_rt import GraphProfiler, common, put
+from ..core._imperative_rt import GraphProfiler, common
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import TensorWeakRef
 from ..core._imperative_rt.core2 import __make_empty_tensor as make_empty_tensor
@@ -18,7 +18,7 @@ from .core._imperative_rt.core2 import apply
 from .core._trace_option import use_symbolic_shape
 from .core._wrap import device as as_device
 from .core.ops.builtin import Copy, GetVarShape
-from .core.tensor.tensor_wrapper import ArrayMethodMixin
+from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
 from .utils.deprecation import deprecated
@@ -42,7 +42,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
         else:
             cn = device._cn
-        # import pdb; pdb.set_trace()
         if isinstance(data, _Tensor):
             obj = _Tensor.__new__(cls, data)
         else:
@@ -14,7 +14,7 @@ from typing import Iterable, List, Optional
 from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry
 from ..core._imperative_rt import ProfilerImpl as _Profiler
-from ..core._imperative_rt.imperative import sync
+from ..core._imperative_rt.core2 import sync
 from ..core._imperative_rt.ops import CollectiveComm
@@ -1,5 +1,5 @@
 from ..core._imperative_rt import TensorSanityCheckImpl
-from ..core._imperative_rt.imperative import sync
+from ..core._imperative_rt.core2 import sync
 
 
 class TensorSanityCheck:
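With the interpreter's Python wrapper removed, `sync` moves from `_imperative_rt.imperative` to `_imperative_rt.core2`, as in the two hunks above. A minimal sketch of the updated call site (the surrounding work is illustrative only, and the behavioral comment assumes the new `sync` mirrors the old module-level one):

    from megengine.core._imperative_rt.core2 import sync

    # ... enqueue asynchronous imperative work ...
    sync()  # block until pending work has been flushed, like the old module-level sync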
@@ -1,229 +0,0 @@
-/**
- * \file imperative/python/src/dispatcher.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
-#include "./dispatcher.h"
-#include "./pyext17.h"
-
-#include "megbrain/exception.h"
-#include "megbrain/utils/hash.h"
-#include "megbrain/utils/small_vector.h"
-
-#include <unordered_map>
-#include <structmember.h>
-
-namespace py = pybind11;
-namespace pyx = pyext17;
-
-namespace {
-
-struct Handler {
-    PyObject* func; // borrowed
-    bool enabled;
-
-    Handler() = default;
-    Handler(PyObject* func_, bool enable = true) : func(func_), enabled(enable) {}
-};
-
-using FastSig = mgb::SmallVector<void*, 8>;
-using MRO = std::vector<Handler*>;
-
-struct Frame {
-    MRO* mro;
-    size_t mro_offset;
-
-    Frame() = default;
-    Frame(MRO* mro_, size_t mro_offset_ = 0) : mro(mro_), mro_offset(mro_offset_) {}
-};
-
-struct FastSigHash {
-    size_t operator()(const FastSig& sig) const {
-        auto* ptr = &sig.front();
-        return mgb::XXHash()
-                .update(ptr, sig.size() * sizeof(FastSig::value_type))
-                .digest();
-    }
-};
-
-struct ObjectIdHash : std::hash<void*> {
-    size_t operator()(const py::handle& h) const {
-        return std::hash<void*>::operator()(h.ptr());
-    }
-};
-
-namespace {
-
-using Container = std::vector<Frame>;
-
-struct DispatcherStack : Container {
-    constexpr static size_t MAX_RECURSIVE_DEPTH = 1024u;
-
-    DispatcherStack() { reserve(MAX_RECURSIVE_DEPTH); }
-
-    template<typename... Args>
-    auto&& emplace_back_safely(Args&& ...args) {
-        mgb_throw_if(size() >= MAX_RECURSIVE_DEPTH, mgb::MegBrainError,
-            "recursion depth %zu is greater than the MAX_RECURSIVE_DEPTH(%zu)",
-            size(), MAX_RECURSIVE_DEPTH);
-        return emplace_back(std::forward<Args>(args)...);
-    }
-};
-
-} // anonymous namespace
-
-struct Dispatcher {
-    std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache;
-    DispatcherStack stack;
-    std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry;
-
-    inline py::handle self() {
-        return pyx::wrap<Dispatcher>::pycast(this);
-    }
-
-    bool prepare_call(PyObject* const* args, Py_ssize_t nargs) {
-        FastSig sig(nargs);
-        for (Py_ssize_t i = 0; i < nargs; ++i) {
-            sig[i] = Py_TYPE(args[i]);
-        }
-        auto it = cache.find(sig);
-        if (it == cache.end()) {
-            if (auto mro = resolve(sig)) {
-                it = cache.emplace(std::move(sig), std::move(mro)).first;
-            } else {
-                return false;
-            }
-        }
-        stack.emplace_back_safely(it->second.get());
-        return true;
-    }
-
-    template<typename T>
-    PyObject* do_call(T&& caller) {
-        auto& frame = stack.back();
-        auto& mro = *frame.mro;
-        auto& i = frame.mro_offset;
-        if (!mro.size()) {
-            PyErr_SetString(PyExc_NotImplementedError, "function not registered in dispatcher");
-            return nullptr;
-        }
-        for (; i < mro.size(); ++i) {
-            if (mro[i]->enabled) {
-                auto ret = caller(mro[i]->func);
-                if (ret != Py_NotImplemented) {
-                    stack.pop_back();
-                    return ret;
-                }
-                Py_DECREF(ret);
-            }
-        }
-        PyErr_SetString(PyExc_NotImplementedError, "mro exhausted");
-        stack.pop_back();
-        return nullptr;
-    }
-
-    std::unique_ptr<MRO> resolve(const FastSig& sig) {
-        try {
-            py::tuple args(sig.size());
-            for (size_t i = 0; i < sig.size(); ++i) {
-                args[i] = (PyObject*)sig[i];
-            }
-            auto mro_iter = self().attr("dispatch_iter")(*args);
-            auto ret = std::make_unique<MRO>();
-            for (auto i : mro_iter) {
-                auto it = registry.find(py::reinterpret_borrow<py::object>(i));
-                if (it == registry.end()) {
-                    PyErr_SetString(PyExc_RuntimeError, "resolved to unregistered function");
-                    return nullptr;
-                }
-                ret->push_back(it->second.get());
-            }
-            return ret;
-        } catch (py::error_already_set& e) {
-            e.restore();
-        } catch (std::runtime_error& e) {
-            PyErr_SetString(PyExc_RuntimeError, e.what());
-        }
-        return nullptr;
-    }
-
-public:
-    static constexpr auto tp_name = "Dispatcher";
-
-    PyObject* tp_call(PyObject* args, PyObject* kwargs) {
-        if (!prepare_call(&PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args))) return nullptr;
-        return do_call([=](PyObject* func) { return PyObject_Call(func, args, kwargs); });
-    }
-
-#if PY_MINOR_VERSION >= 6
-    PyObject* tp_vectorcall(PyObject* const* args, Py_ssize_t nargs) {
-        if (!prepare_call(args, nargs)) return nullptr;
-        return do_call([=](PyObject* func) { return _PyObject_FastCall(func, const_cast<PyObject**>(args), nargs); });
-    }
-#endif
-
-#if PY_MINOR_VERSION >= 6
-    PyObject* super(PyObject* const* args, Py_ssize_t nargs) {
-        if (stack.empty()) {
-            PyErr_SetString(PyExc_RuntimeError, "super called at top level");
-            return nullptr;
-        }
-        stack.emplace_back_safely(stack.back()).mro_offset++;
-        return do_call([=](PyObject* func) { return _PyObject_FastCall(func, const_cast<PyObject**>(args), nargs); });
-    }
-#else
-    PyObject* super(PyObject* args, PyObject* kwargs) {
-        if (stack.empty()) {
-            PyErr_SetString(PyExc_RuntimeError, "super called at top level");
-            return nullptr;
-        }
-        stack.emplace_back_safely(stack.back()).mro_offset++;
-        return do_call([=](PyObject* func) { return PyObject_Call(func, args, kwargs); });
-    }
-#endif
-
-    void enable(PyObject* func) {
-        auto obj = py::reinterpret_borrow<py::object>(func);
-        auto it = registry.find(obj);
-        if (it != registry.end()) {
-            it->second->enabled = true;
-        } else {
-            registry.emplace(std::move(obj), std::make_unique<Handler>(func));
-        }
-    }
-
-    PyObject* disable(PyObject* func) {
-        auto obj = py::reinterpret_borrow<py::object>(func);
-        auto it = registry.find(obj);
-        if (it == registry.end()) {
-            PyErr_SetString(PyExc_ValueError, "function not registered");
-            return nullptr;
-        } else {
-            it->second->enabled = false;
-        }
-        Py_RETURN_NONE;
-    }
-
-    void clear_cache() {
-        cache.clear();
-    }
-};
-
-} // namespace
-
-void init_dispatcher(py::module m) {
-    auto* dispatcher_type = pyx::wrap<Dispatcher>::type()
-        .def<&Dispatcher::enable>("enable")
-        .def<&Dispatcher::disable>("disable")
-        .def<&Dispatcher::clear_cache>("clear_cache")
-#if PY_MINOR_VERSION >= 6
-        .def<&Dispatcher::tp_vectorcall>("call")
-#else
-        .def<&Dispatcher::tp_call>("call")
-#endif
-        .def<&Dispatcher::super>("super")
-        .finalize();
-    if (!dispatcher_type) throw py::error_already_set();
-    m.attr("Dispatcher") = dispatcher_type;
-}
@@ -1,16 +0,0 @@
-/**
- * \file imperative/python/src/dispatcher.h
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
-#pragma once
-
-#include <pybind11/pybind11.h>
-
-void init_dispatcher(pybind11::module);
@@ -51,59 +51,5 @@ make_backward_graph(
 } // namespace
 
 void init_imperative_rt(py::module m) {
-    py::class_<Interpreter::Channel>(m, "Interpreter")
-        .def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) {
-            if (!cn.valid()) {
-                cn = CompNode::load(get_default_device());
-            }
-            constexpr int size_threshhold = TensorShape::MAX_NDIM;
-            if (data.size() > size_threshhold) {
-                return self.put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
-            } else {
-                HostTensorND ret(cn);
-                return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
-            }
-        }, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none())
-        .def("put", py::overload_cast<const DeviceTensorND&>(&Interpreter::Channel::put))
-        .def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) {
-            return self.del(handle);
-        })
-        .def("_swap_in", [](Interpreter::Channel& self, Interpreter::Handle handle) {
-            self.swap_in(handle);
-        })
-        .def("_swap_out", [](Interpreter::Channel& self, Interpreter::Handle handle) {
-            self.swap_out(handle);
-        })
-        .def("_drop", [](Interpreter::Channel& self, Interpreter::Handle handle) {
-            self.drop(handle);
-        })
-        .def("get_value", [](Interpreter::Channel& self, Interpreter::Handle handle) {
-            PyObject* optr = npy::ndarray_from_tensor(self.get_value(handle), npy::ShareType::TRY_SHARE);
-            return py::reinterpret_steal<py::object>(optr);
-        })
-        .def("get_dtype", &Interpreter::Channel::get_dtype)
-        .def("get_device", &Interpreter::Channel::get_device)
-        .def("get_shape", &Interpreter::Channel::get_shape)
-        .def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor)
-        .def("_set_swap_flag", &Interpreter::Channel::set_swap_flag)
-        .def("_set_drop_flag", &Interpreter::Channel::set_drop_flag)
-        .def("apply_op", &Interpreter::Channel::apply_op)
-        .def("config_async_level", &Interpreter::Channel::config_async_level)
-        .def("get_async_level", &Interpreter::Channel::get_async_level)
-        .def("sync", &Interpreter::Channel::sync, py::call_guard<py::gil_scoped_release>());
-
-    std::unique_ptr<Interpreter::Channel> ch = Interpreter::inst().create_channel();
-    m.attr("interpreter") = py::detail::make_caster<decltype(ch)>::cast(
-        std::move(ch), py::return_value_policy::move, {});
-    for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level", "_drop", "_swap_in", "_swap_out", "_set_drop_flag", "_set_swap_flag"}) {
-        m.attr(name) = m.attr("interpreter").attr(name);
-    }
-    m.def("sync", [m]() {
-        m.attr("interpreter").attr("sync")();
-        py::gil_scoped_release _;
-        py_task_q.wait_all_task_finish();
-    });
-
     m.def("make_backward_graph", &make_backward_graph);
 }
@@ -21,8 +21,6 @@
 #include "./graph_rt.h"
 #include "./ops.h"
-#include "./dispatcher.h"
 #include "./tensor.h"
 
 namespace py = pybind11;
@@ -70,7 +68,5 @@ PYBIND11_MODULE(MODULE_NAME, m) {
 )",
             py::getattr(m, "__dict__"));
 
-    init_dispatcher(submodule(m, "dispatcher"));
-
     init_tensor(submodule(m, "core2"));
 }
@@ -16,7 +16,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 import megengine.functional as F
-from megengine.core._imperative_rt import CompNode, TensorAttr, core2, imperative
+from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
 from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
 from megengine.core.autodiff.grad import Grad
 from megengine.core.ops.builtin import Elemwise
@@ -54,10 +54,10 @@ def test_simple_arith():
 def test_tensor_on_device():
     device = megengine.core._imperative_rt.CompNode("cpu0:1")
     x = np.random.rand(10).astype("float32")
-    xx = megengine.core._imperative_rt.put(x, device=device)
-    assert str(megengine.core._imperative_rt.get_device(xx)) == "cpu0:1"
-    np.testing.assert_equal(x, megengine.core._imperative_rt.get_value(xx))
-    megengine.core._imperative_rt.delete(xx)
+    xx = megengine.tensor(x, device=device)
+    assert str(xx.device) == "cpu0:1"
+    np.testing.assert_equal(x, xx.numpy())
+    del xx
 
 
 def test_raw_tensor():
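The updated test above shows the intended replacement for the removed module-level helpers (`put`, `get_value`, `get_device`, `delete`): tensors are now created and inspected through the public `megengine.tensor` API. A minimal standalone sketch on the default device (shape and dtype are arbitrary here):

    import numpy as np
    import megengine

    x = np.random.rand(10).astype("float32")
    xx = megengine.tensor(x)                 # replaces _imperative_rt.put(x, ...)
    print(xx.device)                         # replaces _imperative_rt.get_device(xx)
    np.testing.assert_equal(x, xx.numpy())   # replaces _imperative_rt.get_value(xx)
    del xx                                   # replaces _imperative_rt.delete(xx)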