GitOrigin-RevId: c56c87c88b
tags/v1.2.0
| @@ -12,7 +12,10 @@ import numpy as np | |||
| from ..._imperative_rt import CompNode, DeviceTensorND | |||
| from ..._imperative_rt.imperative import ( | |||
| _drop, | |||
| _get_dev_tensor, | |||
| _swap_in, | |||
| _swap_out, | |||
| apply_op, | |||
| delete, | |||
| get_device, | |||
| @@ -63,6 +66,15 @@ class RawTensor(TensorBase): | |||
def _dev_tensor(self):
    """Return whatever ``_get_dev_tensor`` yields for this tensor's handle."""
    handle = self._handle
    return _get_dev_tensor(handle)
def _drop(self):
    """Pass this tensor's handle to the runtime-level ``_drop``."""
    handle = self._handle
    _drop(handle)
def _swap_in(self):
    """Pass this tensor's handle to the runtime-level ``_swap_in``."""
    handle = self._handle
    _swap_in(handle)
def _swap_out(self):
    """Pass this tensor's handle to the runtime-level ``_swap_out``."""
    handle = self._handle
    _swap_out(handle)
| def __repr__(self): | |||
| return "{}({}, device='{}')".format( | |||
| type(self).__qualname__, repr(self.numpy()), self.device | |||
| @@ -53,6 +53,15 @@ class Tensor(TensorBase): | |||
def numpy(self):
    """Return the result of ``numpy()`` on the underlying ``_data`` object."""
    data = self._data
    return data.numpy()
def _drop(self):
    """Forward the drop request to the underlying ``_data`` object."""
    data = self._data
    data._drop()
def _swap_in(self):
    """Forward the swap-in request to the underlying ``_data`` object."""
    data = self._data
    data._swap_in()
def _swap_out(self):
    """Forward the swap-out request to the underlying ``_data`` object."""
    data = self._data
    data._swap_out()
| class ApplyContext: | |||
| __slots__ = ("inputs", "outputs", "key") | |||
| @@ -473,6 +473,15 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): | |||
def numpy(self):
    """Return the result of ``numpy()`` on the wrapped object."""
    wrapped = self.__wrapped__
    return wrapped.numpy()
def _drop(self):
    """Forward the drop request to the wrapped object."""
    wrapped = self.__wrapped__
    wrapped._drop()
def _swap_in(self):
    """Forward the swap-in request to the wrapped object."""
    wrapped = self.__wrapped__
    wrapped._swap_in()
def _swap_out(self):
    """Forward the swap-out request to the wrapped object."""
    wrapped = self.__wrapped__
    wrapped._swap_out()
| class TensorWrapper(GenericTensorWrapper): | |||
| def __init__(self, data, dtype=None, device=None): | |||
| @@ -966,6 +966,15 @@ class CompiledTensorProxy(RawTensor): | |||
| self.__data = self.__info.data_reader.get_value() | |||
| return self.__data | |||
def _drop(self):
    """Do nothing: this tensor type ignores drop requests."""
    return None
def _swap_in(self):
    """Do nothing: this tensor type ignores swap-in requests."""
    return None
def _swap_out(self):
    """Do nothing: this tensor type ignores swap-out requests."""
    return None
| def __del__(self): | |||
| if self.__info.shape_read and self.__shape is not None: | |||
| self.__info.shape_reader.drop_value() | |||
| @@ -1001,6 +1010,15 @@ class LazyEvalTensor(RawTensor): | |||
| ret = ret.squeeze() | |||
| return ret | |||
def _drop(self):
    """Do nothing: this tensor type ignores drop requests."""
    return None
def _swap_in(self):
    """Do nothing: this tensor type ignores swap-in requests."""
    return None
def _swap_out(self):
    """Do nothing: this tensor type ignores swap-out requests."""
    return None
def _dev_tensor(self):
    """Always raise ``RuntimeError``; this tensor type exposes no device data."""
    message = "cannot access data during symbolic tracing"
    raise RuntimeError(message)
| @@ -1042,6 +1060,15 @@ class TraceMixin: | |||
| active_trace._require_data(self.__handle) | |||
| return super()._dev_tensor() | |||
def _drop(self):
    """Do nothing: traced tensors ignore drop requests."""
    return None
def _swap_in(self):
    """Do nothing: traced tensors ignore swap-in requests."""
    return None
def _swap_out(self):
    """Do nothing: traced tensors ignore swap-out requests."""
    return None
class TracedRawTensor(TraceMixin, RawTensor):
    """``RawTensor`` combined with ``TraceMixin``; adds no members of its own."""
| @@ -68,6 +68,15 @@ void init_imperative_rt(py::module m) { | |||
| .def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) { | |||
| return self.del(handle); | |||
| }) | |||
| .def("_swap_in", [](Interpreter::Channel& self, Interpreter::Handle handle) { | |||
| self.swap_in(handle); | |||
| }) | |||
| .def("_swap_out", [](Interpreter::Channel& self, Interpreter::Handle handle) { | |||
| self.swap_out(handle); | |||
| }) | |||
| .def("_drop", [](Interpreter::Channel& self, Interpreter::Handle handle) { | |||
| self.drop(handle); | |||
| }) | |||
| .def("get_value", [](Interpreter::Channel& self, Interpreter::Handle handle) { | |||
| PyObject* optr = npy::ndarray_from_tensor(self.get_value(handle), npy::ShareType::TRY_SHARE); | |||
| return py::reinterpret_steal<py::object>(optr); | |||
| @@ -76,6 +85,8 @@ void init_imperative_rt(py::module m) { | |||
| .def("get_device", &Interpreter::Channel::get_device) | |||
| .def("get_shape", &Interpreter::Channel::get_shape) | |||
| .def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor) | |||
| .def("_set_swap_flag", &Interpreter::Channel::set_swap_flag) | |||
| .def("_set_drop_flag", &Interpreter::Channel::set_drop_flag) | |||
| .def("apply_op", &Interpreter::Channel::apply_op) | |||
| .def("config_async_level", &Interpreter::Channel::config_async_level) | |||
| .def("get_async_level", &Interpreter::Channel::get_async_level) | |||
| @@ -84,7 +95,7 @@ void init_imperative_rt(py::module m) { | |||
| std::unique_ptr<Interpreter::Channel> ch = Interpreter::inst().create_channel(); | |||
| m.attr("interpreter") = py::detail::make_caster<decltype(ch)>::cast( | |||
| std::move(ch), py::return_value_policy::move, {}); | |||
| for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level"}) { | |||
| for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level", "_drop", "_swap_in", "_swap_out", "_set_drop_flag", "_set_swap_flag"}) { | |||
| m.attr(name) = m.attr("interpreter").attr(name); | |||
| } | |||
| @@ -0,0 +1,124 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import itertools | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.functional as F | |||
| from megengine import Tensor | |||
| from megengine.core._imperative_rt.imperative import _set_drop_flag, _set_swap_flag | |||
| from megengine.module import Linear, Module | |||
| from megengine.optimizer import SGD | |||
# Number of samples in every generated minibatch.
batch_size = 64
# Shapes derived from the batch size: two features per sample, one label per sample.
data_shape, label_shape = (batch_size, 2), (batch_size,)
def minibatch_generator():
    """Yield ``(inputs, labels)`` minibatches forever.

    Each input row is a pair drawn uniformly from [-1, 1); its label is 0 when
    the product of the pair is negative and 1 otherwise. Inputs are yielded as
    float32 and labels as int32, both with ``batch_size`` rows.
    """
    while True:
        inp_data = np.zeros((batch_size, 2))
        label = np.zeros(batch_size, dtype=np.int32)
        for row in range(batch_size):
            # [x0, x1], rescaled from rand()'s [0, 1) range
            sample = np.random.rand(2) * 2 - 1
            inp_data[row, :] = sample
            label[row] = 0 if np.prod(inp_data[row]) < 0 else 1
        yield inp_data.astype(np.float32), label.astype(np.int32)
def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float:
    """Return the fraction of rows whose predicted class matches the XOR label.

    The expected label of a row is 0 when the product of its inputs is
    negative and 1 otherwise; the predicted label is the argmax of the
    corresponding network output row.

    :type data: [[x, y], ...]
    :param data: Input data
    :type pred: [[x_pred, y_pred], ...]
    :param pred: Network output data
    """
    assert len(data) == len(pred)
    hits = sum(
        1
        for sample, scores in zip(data, pred)
        if np.argmax(scores) == (0 if np.prod(sample) < 0 else 1)
    )
    return float(hits) / len(data)
class XORNet(Module):
    """Three-layer fully connected network that calls the swap/drop hooks in ``forward``."""

    def __init__(self):
        self.mid_layers = 14
        self.num_class = 2
        super().__init__()
        hidden, classes = self.mid_layers, self.num_class
        self.fc0 = Linear(classes, hidden, bias=True)
        self.fc1 = Linear(hidden, hidden, bias=True)
        self.fc2 = Linear(hidden, classes, bias=True)

    def forward(self, x):
        """Run the network, swapping out the input and dropping the output."""
        pre_activation = self.fc0(x)
        # The input has been consumed by fc0; request swap-out on it.
        x._swap_out()
        hidden = F.tanh(pre_activation)
        hidden = F.tanh(self.fc1(hidden))
        logits = self.fc2(hidden)
        averaged = (logits + logits) / 2  # in order to test drop()
        averaged._drop()
        return averaged
def test_training_converge_with_swap_and_drop():
    """Train XORNet with swap and drop enabled and check that it converges.

    The swap/drop flags are set on the interpreter for the duration of the
    test and are cleared in a ``finally`` clause, so a failing assertion does
    not leave them enabled for whatever runs afterwards.
    """
    try:
        _set_swap_flag(True)
        _set_drop_flag(True)

        net = XORNet()
        opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
        gm = ad.GradManager().attach(net.parameters())

        def train(data, label):
            with gm:
                pred = net(data)
                loss = F.nn.cross_entropy(pred, label)
                gm.backward(loss)
            return loss

        def infer(data):
            return net(data)

        train_dataset = minibatch_generator()
        losses = []

        for data, label in itertools.islice(train_dataset, 2000):
            data = Tensor(data, dtype=np.float32)
            label = Tensor(label, dtype=np.int32)
            opt.clear_grad()
            loss = train(data, label)
            opt.step()
            losses.append(loss.numpy())

        assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"

        # Evaluate on a regular ngrid x ngrid grid covering [-1, 1]^2.
        ngrid = 10
        x = np.linspace(-1.0, 1.0, ngrid)
        xx, yy = np.meshgrid(x, x)
        xx = xx.reshape((ngrid * ngrid, 1))
        yy = yy.reshape((ngrid * ngrid, 1))
        data = np.concatenate((xx, yy), axis=1).astype(np.float32)

        pred = infer(Tensor(data)).numpy()
        precision = calculate_precision(data, pred)
        assert precision == 1.0, "Test precision must be high enough, get {}".format(
            precision
        )
    finally:
        # Always restore the flags, even when an assertion above fails.
        _set_swap_flag(False)
        _set_drop_flag(False)
| @@ -52,9 +52,37 @@ void ChannelImpl::del(void* handle) { | |||
| m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)}); | |||
| } | |||
| void ChannelImpl::swap_in(void* handle) { | |||
| if (m_enable_evict & SWAP) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| m_worker.add_task(SwapIn{reinterpret_cast<TensorInfo*>(handle)}); | |||
| } | |||
| } | |||
| void ChannelImpl::swap_out(void* handle) { | |||
| if (m_enable_evict & SWAP) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| m_worker.add_task(SwapOut{reinterpret_cast<TensorInfo*>(handle)}); | |||
| } | |||
| } | |||
| void ChannelImpl::drop(void* handle) { | |||
| if (m_enable_evict & DROP) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| m_worker.add_task(Drop{reinterpret_cast<TensorInfo*>(handle)}); | |||
| } | |||
| } | |||
| SmallVector<void*> ChannelImpl::apply_op( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<void*>& inputs) { | |||
| for (auto i : inputs) { | |||
| mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(), | |||
| "invalid handle: %p", i); | |||
| } | |||
| SmallVector<TensorInfo*> input_infos; | |||
| input_infos.reserve(inputs.size()); | |||
| SmallVector<LogicalTensorDesc> input_descs; | |||
| @@ -75,7 +103,8 @@ SmallVector<void*> ChannelImpl::apply_op( | |||
| SmallVector<void*> outputs; | |||
| // FIXME: remove this check when op check is correct | |||
| bool validated_bkp = true; | |||
| for (auto&& desc : output_descs) { | |||
| for (size_t i = 0;i < output_descs.size();i ++) { | |||
| auto&& desc = output_descs[i]; | |||
| if (desc.layout.ndim == 0) { | |||
| validated_bkp = false; | |||
| } | |||
| @@ -85,6 +114,18 @@ SmallVector<void*> ChannelImpl::apply_op( | |||
| cmd.outputs.push_back(info); | |||
| outputs.push_back(info); | |||
| } | |||
| if (m_enable_evict & DROP) { | |||
| for (auto out : cmd.outputs) { | |||
| out->path.op = cmd.op; | |||
| for (auto out_ : cmd.outputs) { | |||
| out->path.outputs.push_back(m_st.at(out_)); | |||
| } | |||
| for (auto inp : cmd.inputs) { | |||
| out->path.inputs.push_back(m_st.at(inp)); | |||
| inp->path.dep_outputs.push_back(m_st.at(out)); | |||
| } | |||
| } | |||
| } | |||
| m_worker.add_task(std::move(cmd)); | |||
| if (!(validated && validated_bkp) && m_async_level == 1) { | |||
| sync(); | |||
| @@ -192,11 +233,18 @@ int ChannelImpl::get_async_level() { | |||
| TensorInfo* ChannelImpl::alloc() { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| return m_pool.alloc(); | |||
| auto info = m_pool.alloc(); | |||
| m_st.insert(info); | |||
| return info; | |||
| } | |||
| void ChannelImpl::free(TensorInfo* ptr) { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| if (ptr->path.dep_outputs.size() > 0) { | |||
| remove_dep(ptr); | |||
| } | |||
| m_st.erase(ptr); | |||
| mgb_assert(ptr->allow_delete, "delete before ref_cnt = 0"); | |||
| m_pool.free(ptr); | |||
| } | |||
| @@ -204,15 +252,136 @@ ChannelImpl::~ChannelImpl() { | |||
| close(); | |||
| } | |||
| void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| dest->value_fetched = ptr->value_fetched(); | |||
| // update tensor desc for static infer | |||
| dest->desc.layout = ptr->layout(); | |||
| dest->desc.comp_node = ptr->comp_node(); | |||
| dest->ptr = std::move(ptr); | |||
| if (m_waitee == dest) { | |||
| m_cv.notify_all(); | |||
| void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) { | |||
| if (notice) { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| dest->value_fetched = ptr->value_fetched(); | |||
| // update tensor desc for static infer | |||
| dest->desc.layout = ptr->layout(); | |||
| dest->desc.comp_node = ptr->comp_node(); | |||
| dest->ptr = std::move(ptr); | |||
| if (m_waitee == dest) { | |||
| m_cv.notify_all(); | |||
| } | |||
| } else { | |||
| dest->value_fetched = ptr->value_fetched(); | |||
| // update tensor desc for static infer | |||
| dest->desc.layout = ptr->layout(); | |||
| dest->desc.comp_node = ptr->comp_node(); | |||
| dest->ptr = std::move(ptr); | |||
| } | |||
| } | |||
| void ChannelImpl::do_swap_out(TensorInfo* dest) { | |||
| if (dest->evict_type == DROP) { | |||
| mgb_log_warn("the evict type of tensor %p was set to DROP, this SWAP operation will be ignored", dest); | |||
| return; | |||
| } | |||
| if (!dest->ptr) { | |||
| return; | |||
| } | |||
| dest->evict_type = SWAP; | |||
| dest->value_fetched = false; | |||
| // TODO: swap in parallel | |||
| dest->h_value.copy_from(dest->ptr->dev_tensor()).sync(); | |||
| dest->ptr.reset(); | |||
| } | |||
| void ChannelImpl::do_swap_in(TensorInfo* dest) { | |||
| if (dest->ptr) { | |||
| return; | |||
| } | |||
| if (dest->h_value.empty()) { | |||
| mgb_log_error("backup of the tensor %p not found", dest); | |||
| return; | |||
| } | |||
| produce_tensor(dest, Tensor::make(dest->h_value), false); | |||
| dest->evict_type = NONE; | |||
| } | |||
| void ChannelImpl::remove_dep(TensorInfo* dest) { | |||
| for (auto i : dest->path.dep_outputs) { | |||
| auto out_ptr = i.lock(); | |||
| if (out_ptr) { | |||
| regenerate(out_ptr.get(), true); | |||
| } | |||
| } | |||
| } | |||
| void ChannelImpl::do_drop(TensorInfo* dest) { | |||
| if (dest->evict_type == SWAP) { | |||
| mgb_log_warn("the evict type of tensor %p was set to SWAP, this DROP operation will be ignored", dest); | |||
| return; | |||
| } | |||
| if (!dest->path.op) { | |||
| mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", dest); | |||
| return; | |||
| } | |||
| if (dest->recompute_times >= m_max_recompute_time) { | |||
| mgb_log_warn("the recomputation time for tensor %p exceeds the limit, this drop operation will be ignored", dest); | |||
| return; | |||
| } | |||
| if (!dest->ptr) { | |||
| return; | |||
| } | |||
| dest->evict_type = DROP; | |||
| dest->value_fetched = false; | |||
| dest->ptr.reset(); | |||
| } | |||
| void ChannelImpl::set_swap_flag(bool flag) { | |||
| if (flag) { | |||
| m_enable_evict |= SWAP; | |||
| } else { | |||
| m_enable_evict &= ~SWAP; | |||
| } | |||
| } | |||
| void ChannelImpl::set_drop_flag(bool flag) { | |||
| if (flag) { | |||
| m_enable_evict |= DROP; | |||
| } else { | |||
| m_enable_evict &= ~DROP; | |||
| } | |||
| } | |||
| void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) { | |||
| if (!info->ptr && info->evict_type != NONE) { | |||
| if (info->evict_type == SWAP) { | |||
| do_swap_in(info); | |||
| } else { | |||
| mgb_assert(info->evict_type == DROP); | |||
| mgb_assert(info->path.op, "recomputation path not found"); | |||
| auto path = info->path; | |||
| SmallVector<TensorPtr> inputs; | |||
| inputs.reserve(path.inputs.size()); | |||
| for (auto i : path.inputs) { | |||
| mgb_assert(i, "invalid history input"); | |||
| if (!i->ptr) { | |||
| regenerate(i.get(), must_drop); | |||
| } | |||
| inputs.push_back(i->ptr); | |||
| } | |||
| auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs); | |||
| for (size_t i = 0; i < outputs.size(); i ++) { | |||
| auto out_ptr = path.outputs[i].lock(); | |||
| if (out_ptr) { | |||
| out_ptr->recompute_times ++; | |||
| if (!out_ptr->ptr && out_ptr->evict_type == DROP) { | |||
| produce_tensor(out_ptr.get(), std::move(outputs[i]), false); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (must_drop) { | |||
| if (info->path.op) { | |||
| info->path.op.reset(); | |||
| info->path.inputs.clear(); | |||
| if (info->evict_type == DROP) { | |||
| info->evict_type = NONE; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -227,6 +396,11 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||
| SmallVector<TensorPtr> tensor_inputs; | |||
| tensor_inputs.reserve(cmd.inputs.size()); | |||
| for (auto i : cmd.inputs) { | |||
| if (m_enable_evict && i->evict_type != NONE) { | |||
| if (!i->ptr) { | |||
| regenerate(i); | |||
| } | |||
| } | |||
| mgb_assert(i->ptr, "Invalid input tensor ptr!"); | |||
| tensor_inputs.push_back(i->ptr); | |||
| } | |||
| @@ -238,6 +412,11 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||
| } else if constexpr (std::is_same_v<T, Del>) { | |||
| free(cmd.dest); | |||
| } else if constexpr (std::is_same_v<T, GetValue>) { | |||
| if (m_enable_evict && cmd.dest->evict_type != NONE) { | |||
| if (!cmd.dest->ptr) { | |||
| regenerate(cmd.dest); | |||
| } | |||
| } | |||
| mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!"); | |||
| cmd.dest->ptr->fetch_value(); | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| @@ -245,6 +424,12 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||
| if (m_waitee == cmd.dest) { | |||
| m_cv.notify_all(); | |||
| } | |||
| } else if constexpr (std::is_same_v<T, SwapIn>) { | |||
| do_swap_in(cmd.dest); | |||
| } else if constexpr (std::is_same_v<T, SwapOut>) { | |||
| do_swap_out(cmd.dest); | |||
| } else if constexpr (std::is_same_v<T, Drop>) { | |||
| do_drop(cmd.dest); | |||
| } else { | |||
| static_assert(!std::is_same_v<T, T>); | |||
| } | |||
| @@ -24,11 +24,34 @@ struct InterpreterImpl : Interpreter { | |||
| std::unique_ptr<Channel> create_channel() override; | |||
| }; | |||
// How a tensor's value has been evicted. SWAP and DROP are distinct bits so
// they can also be OR-ed into the m_enable_evict mask.
enum EvictType {
    NONE = 0,
    SWAP = 1 << 0,
    DROP = 1 << 1,
};
| struct TensorInfo; | |||
| using TensorInfoPtr = std::shared_ptr<TensorInfo>; | |||
| struct TensorInfo { | |||
| TensorPtr ptr; | |||
| LogicalTensorDesc desc; | |||
| bool value_fetched = false; | |||
| bool invalid = false; | |||
| bool allow_delete = false; | |||
| EvictType evict_type = NONE; | |||
| HostTensorND h_value; | |||
| size_t locked = 0; | |||
| size_t recompute_times = 0; | |||
| struct ComputePath { | |||
| std::shared_ptr<OpDef> op; | |||
| SmallVector<TensorInfoPtr> inputs; | |||
| SmallVector<std::weak_ptr<TensorInfo>> outputs; | |||
| SmallVector<std::weak_ptr<TensorInfo>> dep_outputs; | |||
| } path; | |||
| }; | |||
| struct Put { | |||
| @@ -46,10 +69,24 @@ struct Del { | |||
| struct GetValue { | |||
| TensorInfo* dest; | |||
| }; | |||
| struct SwapIn { | |||
| TensorInfo* dest; | |||
| }; | |||
| struct SwapOut { | |||
| TensorInfo* dest; | |||
| }; | |||
| struct Drop { | |||
| TensorInfo* dest; | |||
| }; | |||
| using Command = std::variant<Put, | |||
| ApplyOp, | |||
| Del, | |||
| GetValue>; | |||
| GetValue, | |||
| SwapIn, | |||
| SwapOut, | |||
| Drop>; | |||
| struct ChannelImpl : Interpreter::Channel { | |||
| ChannelImpl() : m_worker(this) {} | |||
| @@ -59,6 +96,9 @@ struct ChannelImpl : Interpreter::Channel { | |||
| Handle put(const DeviceTensorND& value) override; | |||
| void del(Handle) override; | |||
| void swap_in(Handle) override; | |||
| void swap_out(Handle) override; | |||
| void drop(Handle) override; | |||
| SmallVector<Handle> apply_op( | |||
| std::shared_ptr<OpDef> op, | |||
| @@ -73,6 +113,8 @@ struct ChannelImpl : Interpreter::Channel { | |||
| void sync() override; | |||
| void close() override; | |||
| void set_swap_flag(bool) override; | |||
| void set_drop_flag(bool) override; | |||
| void config_async_level(int level) override; | |||
| int get_async_level() override; | |||
| @@ -80,12 +122,17 @@ struct ChannelImpl : Interpreter::Channel { | |||
| private: | |||
| TensorInfo* alloc(); | |||
| void free(TensorInfo*); | |||
| void remove_dep(TensorInfo*); | |||
| void process_one_task(Command&); | |||
| void check_worker_exc_unsafe(); | |||
| void produce_tensor(TensorInfo* dest, TensorPtr ptr); | |||
| void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice); | |||
| void do_swap_out(TensorInfo* dest); | |||
| void do_swap_in(TensorInfo* dest); | |||
| void do_drop(TensorInfo* dest); | |||
| void regenerate(TensorInfo* dest, bool must_drop); | |||
| std::mutex m_mutex; | |||
| std::condition_variable m_cv; | |||
| @@ -93,6 +140,7 @@ private: | |||
| std::unordered_set<Handle> m_valid_handle; | |||
| TensorInfo* m_waitee = nullptr; | |||
| std::exception_ptr m_worker_exc; | |||
| size_t m_enable_evict = 0; | |||
| struct WorkQueue : AsyncQueueSC<Command, WorkQueue> { | |||
| WorkQueue(ChannelImpl* owner) : m_owner(owner) {} | |||
| @@ -103,11 +151,30 @@ private: | |||
| ChannelImpl* m_owner; | |||
| } m_worker; | |||
| struct SharedTensorInfoMap { | |||
| void insert(TensorInfo* info) { | |||
| MGB_LOCK_GUARD(mtx); | |||
| tmap.emplace(info, TensorInfoPtr{info, [](TensorInfo* ptr){ ptr->allow_delete = true;}}); | |||
| } | |||
| void erase(TensorInfo* info) { | |||
| MGB_LOCK_GUARD(mtx); | |||
| tmap.erase(info); | |||
| } | |||
| TensorInfoPtr at(TensorInfo* info) { | |||
| MGB_LOCK_GUARD(mtx); | |||
| return tmap.at(info); | |||
| } | |||
| private: | |||
| std::mutex mtx; | |||
| std::unordered_map<TensorInfo*, TensorInfoPtr> tmap; | |||
| }m_st; | |||
| //! config whether raise error exactly when invoking op. | |||
| //! level 2: both device and user side errors are async; | |||
| //! level 1: user side errors are sync; | |||
| //! level 0: both sync. | |||
| int m_async_level = 2; | |||
| int m_max_recompute_time = 1; | |||
| }; | |||
| } // namespace mgb::imperative::interpreter::intl | |||
| @@ -25,6 +25,9 @@ struct Interpreter { | |||
| virtual Handle put(const DeviceTensorND& value) = 0; | |||
| virtual void del(Handle) = 0; | |||
| virtual void swap_in(Handle) = 0; | |||
| virtual void swap_out(Handle) = 0; | |||
| virtual void drop(Handle) = 0; | |||
| virtual SmallVector<Handle> apply_op( | |||
| std::shared_ptr<OpDef> op, | |||
| @@ -39,6 +42,8 @@ struct Interpreter { | |||
| virtual void sync() = 0; | |||
| virtual void close() = 0; | |||
| virtual void set_swap_flag(bool) = 0; | |||
| virtual void set_drop_flag(bool) = 0; | |||
| virtual void config_async_level(int level) = 0; | |||
| virtual int get_async_level() = 0; | |||