GitOrigin-RevId: 020d1e88d4
tags/v1.2.0
@@ -33,7 +33,7 @@ def _run_wrapped(
 class launcher:
     """Decorator for launching multiple processes in single-machine multi-gpu training.

     :param func: the function you want to launch in distributed mode.
     :param n_gpus: how many devices on each node.
     :param world_size: how many devices in total.
@@ -32,7 +32,7 @@ namespace views = ranges::views;
 namespace mgb::imperative::python {

-std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
+interpreter::Interpreter::Channel* interpreter_for_py;

 PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing,
         *cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode;
@@ -673,7 +673,9 @@ py::object make_empty_tensorwrapper() {
 }

 void init_tensor(py::module m) {
-    interpreter_for_py = interpreter::Interpreter::inst().create_channel();
+    imperative::Tensor::static_initialize();
+    static auto sl_interpreter_for_py = interpreter::Interpreter::inst().create_channel();
+    interpreter_for_py = sl_interpreter_for_py.get();

     auto* tensor_type = TensorWrapper::wrap_t::type()
                                 .def<&TensorWrapper::numpy>("numpy")
@@ -724,6 +726,8 @@ void init_tensor(py::module m) {
           [](int level) { interpreter_for_py->config_async_level(level); });
     m.def("get_async_level",
           []() { return interpreter_for_py->get_async_level(); });
+    m.def("set_buffer_length",
+          [](int length) { interpreter_for_py->set_buffer_length(length); });
     m.def("sync",
           []() {
               interpreter_for_py->sync();
@@ -34,7 +34,7 @@ struct ObjectPtr : B {
 namespace mgb::imperative::python {

-extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;
+extern interpreter::Interpreter::Channel* interpreter_for_py;

 class SharedHandle {
     using Handle = interpreter::Interpreter::Handle;
@@ -111,6 +111,11 @@ void BlobManagerImpl::defrag(const CompNode& cn) {
     MGB_TRY{cn.free_device(cn.alloc_device(tot_sz));}
     MGB_CATCH(MemAllocError&, {})

+    // sort blobs by creation time; this may help reduce memory fragmentation
+    std::sort(blob_data_arrary.begin(), blob_data_arrary.end(), [](auto& lhs, auto& rhs) {
+        return lhs.blob->id() < rhs.blob->id();
+    });
+
     // allocate for each storage
     for (auto i : blob_data_arrary) {
         DeviceTensorStorage d_storage = DeviceTensorStorage(cn);
@@ -22,10 +22,10 @@ class FunctionHooker;
 template <typename TRet, typename... TArgs>
 class FunctionHooker<TRet(TArgs...)> {
 public:
-    using FunctionType = thin_function<TRet(TArgs&&...)>;
+    using FunctionType = thin_function<TRet(TArgs...)>;
     // Type of hooks. A hook should accept a real function as argument
    // and invoke it at an appropriate time
-    using HookType = thin_function<TRet(FunctionType, TArgs&&...)>;
+    using HookType = thin_function<TRet(FunctionType, TArgs...)>;

     explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} {
         m_backup = {nullptr, [](FunctionType*){}};
     }

@@ -43,7 +43,7 @@ public:
         m_backup = decltype(m_backup)(backup, restorer);
     }
     // Replace with hooked version
-    *m_fptr = [func = *m_fptr, hook](TArgs&&... args) -> TRet {
+    *m_fptr = [func = *m_fptr, hook](TArgs... args) -> TRet {
         return hook(func, std::forward<TArgs>(args)...);
     };
     // Convenient for chained calls

@@ -58,7 +58,7 @@ private:
 // Helps to deduce template args
 template <typename TRet, typename... TArgs>
 FunctionHooker(thin_function<TRet(TArgs...)>* f)
-        ->FunctionHooker<TRet(TArgs...)>;
+        -> FunctionHooker<TRet(TArgs...)>;

 template<typename TSignature>
 auto make_shared_hook(thin_function<TSignature>* fptr){
@@ -11,20 +11,20 @@
 #include "./interpreter_impl.h"
 #include "megbrain/common.h"
+#include "megbrain/imperative/opr_utility.h"
 #include "megbrain/imperative/ops/backward_graph.h"
+#include "megbrain/imperative/ops/autogen.h"

 using namespace mgb;
 using namespace imperative;
 using namespace interpreter;
 using namespace interpreter::intl;

 std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
     return std::make_unique<ChannelImpl>();
 }

 Interpreter& Interpreter::inst() {
-    Tensor::_static_init();
     static InterpreterImpl inst_;
     return inst_;
 }
@@ -35,7 +35,7 @@ void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     info->desc.comp_node = value.comp_node();
     info->desc.value = value.proxy_to_default_cpu();
     m_valid_handle.insert(info);
-    m_worker.add_task(Put{info, value, no_cache});
+    m_buffer.enqueue(Put{info, value, no_cache});
     return info;
 }

@@ -50,14 +50,14 @@ void* ChannelImpl::put(const DeviceTensorND& data) {

 void ChannelImpl::del(void* handle) {
     mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
-    m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)});
+    m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
 }

 void ChannelImpl::swap_in(void* handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                    "invalid handle: %p", handle);
-        m_worker.add_task(SwapIn{reinterpret_cast<TensorInfo*>(handle)});
+        m_buffer.enqueue(SwapIn{reinterpret_cast<TensorInfo*>(handle)});
     }
 }

@@ -65,7 +65,7 @@ void ChannelImpl::swap_out(void* handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                    "invalid handle: %p", handle);
-        m_worker.add_task(SwapOut{reinterpret_cast<TensorInfo*>(handle)});
+        m_buffer.enqueue(SwapOut{reinterpret_cast<TensorInfo*>(handle)});
     }
 }

@@ -73,7 +73,7 @@ void ChannelImpl::drop(void* handle) {
     if (m_enable_evict & DROP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                    "invalid handle: %p", handle);
-        m_worker.add_task(Drop{reinterpret_cast<TensorInfo*>(handle)});
+        m_buffer.enqueue(Drop{reinterpret_cast<TensorInfo*>(handle)});
     }
 }
@@ -88,14 +88,16 @@ SmallVector<void*> ChannelImpl::apply_op(
     input_infos.reserve(inputs.size());
     SmallVector<LogicalTensorDesc> input_descs;
     input_descs.reserve(inputs.size());
-    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
-    for (auto i : inputs) {
-        auto info = reinterpret_cast<TensorInfo*>(i);
-        mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
-        input_infos.push_back(info);
-        input_descs.push_back(info->desc);
+    {
+        MGB_LOCK_GUARD(m_mutex);
+        for (auto i : inputs) {
+            auto info = reinterpret_cast<TensorInfo*>(i);
+            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
+            input_infos.push_back(info);
+            input_descs.push_back(info->desc);
+        }
     }
-    lock.unlock();

     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
     ApplyOp cmd{std::move(op)};

@@ -127,7 +129,7 @@ SmallVector<void*> ChannelImpl::apply_op(
             }
         }
     }
-    m_worker.add_task(std::move(cmd));
+    m_buffer.enqueue(std::move(cmd));
     if (!(validated && validated_bkp) && m_async_level == 1) {
         sync();
     } else if (m_async_level == 0) {
@@ -150,7 +152,7 @@ HostTensorND ChannelImpl::get_value(void* handle) {
     if (!info->value_fetched) {
         mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
         m_waitee = info;
-        m_worker.add_task(GetValue{info});
+        m_buffer.enqueue(GetValue{info});
         m_cv.wait(lock, [&]() {
             check_worker_exc_unsafe();
             return info->value_fetched;

@@ -171,6 +173,7 @@ TensorShape ChannelImpl::get_shape(void* handle) {
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee);
     m_waitee = info;
+    m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
         return bool(info->ptr);

@@ -206,6 +209,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee);
     m_waitee = info;
+    m_buffer.enqueue(Flush{info});
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
         return bool(info->ptr);

@@ -215,6 +219,9 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
 }

 void ChannelImpl::sync() {
+    if (!m_buffer.empty()) {
+        m_buffer.enqueue(Flush{});
+    }
     m_worker.wait_all_task_finish();
     MGB_LOCK_GUARD(m_mutex);
     check_worker_exc_unsafe();

@@ -350,6 +357,10 @@ void ChannelImpl::set_drop_flag(bool flag) {
     }
 }

+void ChannelImpl::set_buffer_length(int length) {
+    m_buffer.set_capacity(length);
+}
+
 void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
     if (!info->ptr && info->evict_type != NONE) {
         if (info->evict_type == SWAP) {
@@ -401,6 +412,7 @@ void ChannelImpl::process_one_task(Command& cmd) {
         } else if constexpr (std::is_same_v<T, ApplyOp>) {
             SmallVector<TensorPtr> tensor_inputs;
             tensor_inputs.reserve(cmd.inputs.size());
+            // refcnt == 1, owners: [TensorInfo::ptr]
             for (auto i : cmd.inputs) {
                 if (m_enable_evict && i->evict_type != NONE) {
                     if (!i->ptr) {

@@ -408,9 +420,20 @@ void ChannelImpl::process_one_task(Command& cmd) {
                     }
                 }
                 mgb_assert(i->ptr, "Invalid input tensor ptr!");
+                // refcnt ++, owners: [i->ptr, tensor_inputs]
                 tensor_inputs.push_back(i->ptr);
             }
-            auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs);
+            // Dels fused in by the command buffer. @see: CommandBuffer::fuse_del
+            // If dest is inplace-able, its refcnt drops to 1 after the Del, leaving
+            // tensor_inputs as the sole owner.
+            // Note: for exprs like 'y = x op x', inplace is not supported yet, but the
+            // Del is still fused.
+            for (auto* del : cmd.dels) {
+                // refcnt --, owners: [tensor_inputs]
+                // if it drops to 1, that is detected in proxy_graph_detail::apply_on_physical_tensor
+                free(del);
+            }
+            // std::move is REQUIRED here to drop the duplicated references.
+            auto tensor_outputs = OpDef::apply_on_physical_tensor(
+                    *cmd.op, std::move(tensor_inputs));
             mgb_assert(tensor_outputs.size() == cmd.outputs.size());
             for (size_t i = 0; i < tensor_outputs.size(); ++i) {
                 produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
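Editor's note: the refcount bookkeeping in the comments above is the crux of this hunk. After the fused Del has released TensorInfo::ptr, the moved tensor_inputs vector holds the last reference, which is what lets the op reuse the input storage. Below is a minimal standalone sketch of that idea using plain std::shared_ptr; the names (Buffer, apply) are hypothetical stand-ins, not MegEngine code.

// Sketch: detecting sole ownership after a "fused Del" and reusing the buffer in place.
#include <cstdio>
#include <memory>
#include <utility>
#include <vector>

using Buffer = std::vector<float>;

// Stand-in for an op implementation that takes its inputs by value.
std::shared_ptr<Buffer> apply(std::vector<std::shared_ptr<Buffer>> inputs) {
    auto& in = inputs[0];
    if (in.use_count() == 1) {                 // sole owner: safe to write in place
        for (auto& v : *in) v += 1.0f;
        return std::move(in);
    }
    auto out = std::make_shared<Buffer>(*in);  // otherwise fall back to a copy
    for (auto& v : *out) v += 1.0f;
    return out;
}

int main() {
    auto x = std::make_shared<Buffer>(Buffer{1.f, 2.f});
    std::vector<std::shared_ptr<Buffer>> inputs{x};
    x.reset();  // plays the role of the fused Del: the producer-side reference is gone

    // Moving the vector is essential; passing a copy (or a const&) would keep a second
    // reference alive, so use_count() could never reach 1 inside apply().
    auto y = apply(std::move(inputs));
    std::printf("%.1f %.1f\n", (*y)[0], (*y)[1]);  // 2.0 3.0, computed in place
}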
@@ -436,8 +459,12 @@ void ChannelImpl::process_one_task(Command& cmd) {
             do_swap_out(cmd.dest);
         } else if constexpr (std::is_same_v<T, Drop>) {
             do_drop(cmd.dest);
+        } else if constexpr (std::is_same_v<T, Move>) {
+            produce_tensor(cmd.dest, cmd.src->ptr);
+            free(cmd.src);
         } else {
-            static_assert(!std::is_same_v<T, T>);
+            static_assert(std::is_same_v<T, Flush> ||
+                          std::is_same_v<T, Nop>);
         }
     } catch (...) {
         MGB_LOCK_GUARD(m_mutex);

@@ -454,7 +481,6 @@ void ChannelImpl::process_one_task(Command& cmd) {
     }, cmd);
 }

 void ChannelImpl::check_worker_exc_unsafe() {
     if (m_worker_exc) {
         std::exception_ptr exc;

@@ -462,3 +488,120 @@ void ChannelImpl::check_worker_exc_unsafe() {
         std::rethrow_exception(exc);
     }
 }
+
+void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
+    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
+        return;
+    }
+    auto command_repr = std::visit([](auto& cmd) { return cmd.to_string(); }, cmd);
+    mgb_log_debug("%s Enqueued", command_repr.c_str());
+    m_commands.push_back(std::move(cmd));
+    auto flush_pos = flush_pos_for(m_commands.back());
+    flush(flush_pos);
+}
+
+void ChannelImpl::CommandBuffer::flush(Handle pos) {
+    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
+        auto command_repr = std::visit([](auto& cmd) { return cmd.to_string(); }, *iter);
+        mgb_log_debug("%s Flushed", command_repr.c_str());
+        m_owner->m_worker.add_task(std::move(*iter));
+    }
+    m_commands.erase(m_commands.begin(), pos);
+}
+
+auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
+    return std::visit([this](const auto& cmd) {
+        using T = std::decay_t<decltype(cmd)>;
+        if constexpr (std::is_same_v<T, ApplyOp>) {
+            auto* op_type = cmd.op->dyn_typeinfo();
+            if (op_type == RemoteRecv::typeinfo() ||
+                op_type == RemoteSend::typeinfo() ||
+                op_type == CollectiveComm::typeinfo() ||
+                op_type == opr::InputCallback::typeinfo() ||
+                op_type == opr::OutputCallback::typeinfo() ||
+                op_type == BackwardGraph::typeinfo()) {
+                return m_commands.end();
+            }
+        } else if constexpr (std::is_same_v<T, GetValue>) {
+            return m_commands.end();
+        } else if constexpr (std::is_same_v<T, Flush>) {
+            if (cmd.dest == nullptr) {
+                return m_commands.end();
+            }
+            auto produce_iter = find_produce(cmd.dest, {m_commands.begin(), m_commands.end()});
+            if (produce_iter != m_commands.end()) {
+                return produce_iter + 1;
+            }
+        }
+        if (m_commands.size() > m_capacity) {
+            return m_commands.begin() + (m_commands.size() - m_capacity);
+        }
+        return m_commands.begin();
+    }, cmd);
+}
+
+/**
+ * 1. Find the ApplyOp that consumes dest among the buffered commands.
+ * 2. Check whether dest is used again after that ApplyOp; if so, return false.
+ * 3. Otherwise fuse the Del into the ApplyOp and return true.
+ */
+bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
+    auto* dest = cmd.dest;
+    // TODO: eliminate Puts
+    auto begin = m_commands.begin(), end = m_commands.end();
+    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd) {
+        if (auto* apply = std::get_if<ApplyOp>(&cmd)) {
+            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
+        }
+        return false;
+    });
+    if (apply_iter == end || find_last_usage(dest, {apply_iter + 1, end}) != end) {
+        return false;
+    }
+    mgb_log_debug("%s Fused", cmd.to_string().c_str());
+    std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
+    return true;
+}
+
+auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
+        -> Handle {
+    auto found = range[1];
+    for (auto iter = range[0]; iter != range[1]; ++iter) {
+        std::visit([&](const auto& cmd) {
+            using T = std::decay_t<decltype(cmd)>;
+            if constexpr (std::is_same_v<T, ApplyOp>) {
+                if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
+                               dest) > 0) {
+                    found = iter;
+                }
+            } else if constexpr (std::is_same_v<T, GetValue>) {
+                if (cmd.dest == dest) {
+                    found = iter;
+                }
+            } else if constexpr (std::is_same_v<T, SwapIn> ||
+                                 std::is_same_v<T, SwapOut> ||
+                                 std::is_same_v<T, Drop>) {
+                // TODO: ignore swap-like commands; just remove them from the buffer
+                if (cmd.dest == dest) {
+                    found = iter;
+                }
+            }
+        }, *iter);
+    }
+    return found;
+}
+
+auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
+        -> Handle {
+    return std::find_if(range[0], range[1], [dest](auto& cmd) {
+        return std::visit([dest](const auto& cmd) {
+            using T = std::decay_t<decltype(cmd)>;
+            if constexpr (std::is_same_v<T, ApplyOp>) {
+                return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
+            } else if constexpr (std::is_same_v<T, Put>) {
+                return cmd.dest == dest;
+            }
+            return false;
+        }, cmd);
+    });
+}
@@ -9,13 +9,15 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#include <variant>
+#include <deque>
 #include <future>
+#include <list>
+#include <unordered_set>
+#include <variant>

 #include "megbrain/utils/mempool.h"
 #include "megbrain/imperative/interpreter.h"

 namespace mgb::imperative::interpreter::intl {

 using Handle = Interpreter::Handle;
@@ -58,39 +60,99 @@ struct Put {
     TensorInfo* dest;
     HostTensorND value;
     bool no_cache = false;
+
+    std::string to_string() const { return ssprintf("Command: Put %p", dest); }
 };

 struct ApplyOp {
     std::shared_ptr<OpDef> op;
     SmallVector<TensorInfo*> inputs;
     SmallVector<TensorInfo*> outputs;
+    SmallVector<TensorInfo*> dels;
+
+    std::string to_string() const {
+        std::string builder{"Command: ApplyOp {"};
+        builder += "inputs [";
+        for (auto* input : inputs) {
+            builder += ssprintf("%p, ", input);
+        }
+        builder += "], outputs [";
+        for (auto* output : outputs) {
+            builder += ssprintf("%p, ", output);
+        }
+        builder += "], dels [";
+        for (auto* del : dels) {
+            builder += ssprintf("%p, ", del);
+        }
+        builder += "]";
+        return builder;
+    }
 };

 struct Del {
     TensorInfo* dest;
+
+    std::string to_string() const { return ssprintf("Command: Del %p", dest); }
 };

 struct GetValue {
     TensorInfo* dest;
+
+    std::string to_string() const {
+        return ssprintf("Command: GetValue %p", dest);
+    }
 };

 struct SwapIn {
     TensorInfo* dest;
+
+    std::string to_string() const {
+        return ssprintf("Command: SwapIn %p", dest);
+    }
 };

 struct SwapOut {
     TensorInfo* dest;
+
+    std::string to_string() const {
+        return ssprintf("Command: SwapOut %p", dest);
+    }
 };

 struct Drop {
     TensorInfo* dest;
+
+    std::string to_string() const {
+        return ssprintf("Command: Drop %p", dest);
+    }
+};
+
+struct Move {
+    TensorInfo* src;
+    TensorInfo* dest;
+
+    std::string to_string() const {
+        return ssprintf("Command: Move %s to %s",
+                        src->desc.layout.to_string().c_str(),
+                        dest->desc.layout.to_string().c_str());
+    }
+};
+
+struct Flush {
+    TensorInfo* dest = nullptr;
+
+    std::string to_string() const {
+        return ssprintf("Command: Flush %p", dest);
+    }
+};
+
+struct Nop {
+    std::string to_string() const { return "Command: Nop"; }
 };

 using Command = std::variant<Put,
                              ApplyOp,
                              Del,
                              GetValue,
                              SwapIn,
                              SwapOut,
-                             Drop>;
+                             Drop,
+                             Move,
+                             Flush,
+                             Nop>;

 struct ChannelImpl : Interpreter::Channel {
-    ChannelImpl() : m_worker(this) {}
+    ChannelImpl() : m_worker(this), m_buffer(this) {}
     ~ChannelImpl() override;

     Handle put(const HostTensorND& value, bool no_cache) override;
@@ -116,6 +178,7 @@ struct ChannelImpl : Interpreter::Channel {
     void close() override;
     void set_swap_flag(bool) override;
     void set_drop_flag(bool) override;
+    void set_buffer_length(int) override;
     void config_async_level(int level) override;
     int get_async_level() override;
@@ -174,7 +237,56 @@ private:
         std::mutex mtx;
         std::unordered_map<TensorInfo*, TensorInfoPtr> tmap;
     }m_st;
+
+    /**
+     * Buffer a window of commands so that later commands can be fused into them.
+     * Example:
+     * ---------------------------------------------------------------------
+     * | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1}   |
+     * ---------------------------------------------------------------------
+     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1}  |
+     * ---------------------------------------------------------------------
+     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ...        |
+     * ---------------------------------------------------------------------
+     * The fused Apply may then be executed in place. See ChannelImpl::process_one_task.
+     */
+    struct CommandBuffer {
+        CommandBuffer(ChannelImpl* owner) : m_owner(owner) {
+            int capacity = 3;
+            if (const char* capacity_str = MGB_GETENV("MEGENGINE_COMMAND_BUFFER_LENGTH")) {
+                capacity = atoi(capacity_str);
+            }
+            set_capacity(capacity);
+        }
+        void enqueue(Command cmd);
+        bool empty() const {
+            return m_commands.empty();
+        }
+        void set_capacity(int capacity) {
+            mgb_assert(capacity >= 0 && capacity < 100, "invalid command buffer length");
+            m_capacity = capacity;
+        }
+    private:
+        ChannelImpl* m_owner;
+        size_t m_capacity;
+        std::deque<Command> m_commands;
+
+        using Handle = decltype(m_commands)::iterator;
+        // [begin, end)
+        using Range = std::array<Handle, 2>;
+
+        // Launch commands in the range [m_commands.begin(), pos)
+        void flush(Handle pos);
+        // Select the flush position for an incoming cmd
+        Handle flush_pos_for(const Command& cmd);
+        // Fuse a Del command into a suitable ApplyOp
+        bool fuse_del(const Del& cmd);
+        // Return the last position in range where dest is used; range[1] if dest is not used
+        Handle find_last_usage(TensorInfo* dest, Range range);
+        // Return the position where dest is produced; range[1] if not found
+        Handle find_produce(TensorInfo* dest, Range range);
+    } m_buffer;
+
     //! config whether raise error exactly when invoking op.
     //! level 2: both device and user side errors are async;
     //! level 1: user side errors are sync;
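Editor's note: beyond the barrier rules in flush_pos_for, the buffer also flushes its oldest entries once it grows past m_capacity (default 3, overridable via MEGENGINE_COMMAND_BUFFER_LENGTH or set_buffer_length). The toy sketch below illustrates only that bounded-window fallback; ToyBuffer and its members are hypothetical and not part of this patch.

// Sketch: a bounded command window that hands overflow to the worker, oldest first.
#include <cstdio>
#include <deque>
#include <string>

struct ToyBuffer {
    size_t capacity = 3;               // cf. set_capacity() / MEGENGINE_COMMAND_BUFFER_LENGTH
    std::deque<std::string> commands;

    void enqueue(std::string cmd) {
        commands.push_back(std::move(cmd));
        while (commands.size() > capacity) {
            // stands in for m_owner->m_worker.add_task(std::move(*iter)) in flush()
            std::printf("flush to worker: %s\n", commands.front().c_str());
            commands.pop_front();
        }
    }
};

int main() {
    ToyBuffer buf;
    for (const char* c : {"Put", "ApplyOp", "ApplyOp", "Del", "GetValue"}) {
        buf.enqueue(c);  // "Put" and the first "ApplyOp" are flushed once the window fills
    }
}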
@@ -32,8 +32,8 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node(
 SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
     const OpDef& def,
-    const SmallVector<TensorPtr>& inputs) {
-    return def.trait()->apply_on_physical_tensor(def, inputs);
+    SmallVector<TensorPtr> inputs) {
+    return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
 }

 VarNodeArray OpDef::apply_on_var_node(
@@ -17,17 +17,17 @@ namespace mgb {
 namespace imperative {

 namespace detail {
-template<typename Signature>
+template <typename Signature>
 struct OpMeth;
-template<typename RType, typename ...Args>
-struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> {
+template <typename RType, typename... Args>
+struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> {
     using Base = thin_function<RType(Args...)>;
     using Base::Base;
     RType operator()(Args... args) const {
         if (!this->Base::operator bool()) {
             mgb_throw(MegBrainError, "Not Implemented");
         }
-        return this->Base::operator ()(args...);
+        return this->Base::operator()(std::forward<Args>(args)...);
     }
 };
 template<typename T>

@@ -56,7 +56,7 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
         return opr->usable_output();
     }
 };
-}  // detail
+}  // namespace detail

 using OpDefMaker = detail::OpMeth<
         decltype(OpDef::make_from_op_node)>;
@@ -56,17 +56,15 @@ protected:
         return {};
     }

-    AsyncReleaser() {
-        EventPool::without_timer();
-    }
-
 public:
     static AsyncReleaser* inst() {
         static AsyncReleaser releaser;
         return &releaser;
     }

-    ~AsyncReleaser() { m_waiter.wait_task_queue_empty(); }
+    ~AsyncReleaser() {
+        m_waiter.wait_task_queue_empty();
+    }

     void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }

@@ -85,8 +83,6 @@ public:
 class CompNodeSyncManager : public CompNodeDepedentObject {
     ThinHashMap<Blob*, std::unique_ptr<CompNode::Event>> m_blob2event;
     std::mutex m_mtx;
-private:
-    static CompNodeSyncManager mgr;
 public:
     std::shared_ptr<void> on_comp_node_finalize() override {
         MGB_LOCK_GUARD(m_mtx);

@@ -94,8 +90,9 @@ public:
         return {};
     }

-    static CompNodeSyncManager* inst() {
-        return &mgr;
+    static CompNodeSyncManager& inst() {
+        static CompNodeSyncManager sl_inst;
+        return sl_inst;
     }

     CompNode::Event* get_or_create_event(Blob* blob) {

@@ -113,7 +110,6 @@ public:
         m_blob2event.erase(blob);
     }
 };
-CompNodeSyncManager CompNodeSyncManager::mgr;

 // Cache for small blobs
 // 1. A blob has to be seen twice (within a window) to be eligible for cache

@@ -236,9 +232,12 @@ struct MultiCNConstTensorCache : CompNodeDepedentObject {
         MGB_LOCK_GUARD(mtx);
         return cn2cache[hv.comp_node()].lookup(hv);
     }
-};
-
-MultiCNConstTensorCache const_tensor_cache;
+
+    static MultiCNConstTensorCache& inst() {
+        static MultiCNConstTensorCache sl_inst;
+        return sl_inst;
+    }
+};

 }  // namespace
@@ -246,20 +245,26 @@ void EventDeleter::operator()(CompNode::Event* event) {
     EventPool::without_timer().free(event);
 }

+namespace {
+std::atomic_uint64_t next_blob_id = 0;
+}
+
 Blob::Blob(const DeviceTensorStorage& s):
     m_comp_node{s.comp_node()}, m_storage{s.raw_storage()},
     m_size{s.size()} {
+    m_id = next_blob_id++;
     BlobManager::inst()->register_blob(this);
 }

 Blob::Blob(CompNode cn, size_t sz):
     m_comp_node{cn}, m_storage{}, m_size{sz} {
+    m_id = next_blob_id++;
     BlobManager::inst()->register_blob(this);
 }

 Blob::~Blob() {
     BlobManager::inst()->unregister_blob(this);
-    CompNodeSyncManager::inst()->remove(this);
+    CompNodeSyncManager::inst().remove(this);
 }

 const Blob::RawStorage& Blob::storage() {
@@ -302,7 +307,7 @@ Tensor::Tensor(const BlobPtr blob, const size_t offset, const TensorLayout& layout)
     : m_layout{layout}, m_blob{blob}, m_offset{offset} {}

 TensorPtr Tensor::make(const HostTensorND& hv) {
-    auto&& blob = const_tensor_cache.lookup(hv);
+    auto&& blob = MultiCNConstTensorCache::inst().lookup(hv);
     if (blob) {
         return make(std::forward<decltype(blob)>(blob), hv.layout(), hv);
     }

@@ -366,13 +371,17 @@ void Tensor::add_release_callback(CompNode cn) {
 }

 CompNode::Event* Tensor::get_or_create_event() {
-    auto e = CompNodeSyncManager::inst()->get_or_create_event(m_blob.get());
+    auto e = CompNodeSyncManager::inst().get_or_create_event(m_blob.get());
     e->record();
     return e;
 }

-void Tensor::_static_init() {
+void Tensor::static_initialize() {
+    EventPool::with_timer();
     EventPool::without_timer();
+    AsyncReleaser::inst();
+    CompNodeSyncManager::inst();
+    MultiCNConstTensorCache::inst();
 }

 }  // namespace imperative
@@ -117,7 +117,7 @@ void Profiler::start(uint32_t flags) {
         auto hook_apply_on_var_node =
                 make_shared_hook(&trait.apply_on_var_node);
         hook_apply_on_physical_tensor->apply_hook([this, flags]
-                (auto&& apply, const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+                (auto&& apply, const OpDef& def, SmallVector<TensorPtr> inputs) {
             auto shape2vector = [](const TensorShape& shape) {
                 std::vector<size_t> vector_shape;
                 for (size_t i = 0; i < shape.ndim; i++) {
@@ -11,6 +11,7 @@
 #include "./proxy_graph.h"
 #include "megbrain/imperative/proxy_graph_detail.h"
+#include "megbrain/imperative/ops/autogen.h"

 namespace mgb {
 namespace imperative {

@@ -70,11 +71,34 @@ void exec(const OpDef& def,

 SmallVector<TensorPtr>
 apply_on_physical_tensor(const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
-    auto desc = infer_output_attrs(def, inputs);
-    SmallVector<TensorPtr> outputs;
-    for (auto&& i : desc) {
-        outputs.push_back(Tensor::make(i.layout, i.comp_node));
+        SmallVector<TensorPtr> inputs) {
+    auto output_descs = infer_output_attrs(def, inputs);
+    SmallVector<TensorPtr> outputs(output_descs.size(), {});
+    for (size_t i = 0; i < outputs.size(); i++) {
+        auto& output = outputs[i];
+        auto& output_desc = output_descs[i];
+        if (def.same_type<Elemwise>()) {
+            for (size_t j = 0; j < inputs.size(); j++) {
+                // TODO: reindex inputs to support inplace exprs like 'y = x op x'.
+                auto& input = inputs[j];
+                // Because inputs are passed by value, when input, input->blob() and its
+                // storage are all uniquely owned, the only references live on this stack
+                // frame, so the storage can be reused safely.
+                // @see: interpreter::intl::ChannelImpl::process_one_task
+                if (input.unique() && input->blob().unique() && input->blob()->storage().unique() &&
+                    input->layout().dtype == output_desc.layout.dtype &&
+                    input->layout().eq_layout(output_desc.layout) &&
+                    input->comp_node() == output_desc.comp_node) {
+                    static std::atomic_llong inplace_count = 0;
+                    mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld",
+                                  output_desc.layout.to_string().c_str(), ++inplace_count);
+                    output = Tensor::make(input->blob(), input->layout(), input->offset());
+                    break;
+                }
+            }
+        }
+        if (!output) {
+            output = Tensor::make(output_desc.layout, output_desc.comp_node);
+        }
     }
     exec(def, inputs, outputs);
     return outputs;
@@ -44,6 +44,7 @@ struct Interpreter {
     virtual void close() = 0;
     virtual void set_swap_flag(bool) = 0;
     virtual void set_drop_flag(bool) = 0;
+    virtual void set_buffer_length(int) = 0;

     virtual void config_async_level(int level) = 0;
     virtual int get_async_level() = 0;
@@ -38,7 +38,7 @@ public:
     static SmallVector<TensorPtr> apply_on_physical_tensor(
             const OpDef& def,
-            const SmallVector<TensorPtr>& inputs);
+            SmallVector<TensorPtr> inputs);

     static cg::VarNodeArray apply_on_var_node(
             const OpDef& def,
@@ -46,11 +46,16 @@ public:
     size_t size() const {
         return m_size;
     }

+    size_t id() const {
+        return m_id;
+    }
+
 private:
     friend class BlobManagerImpl;
     CompNode m_comp_node;
     mutable RawStorage m_storage;
     size_t m_size = 0;
+    size_t m_id;
 };

 struct EventDeleter {
@@ -134,8 +139,7 @@ public:
     // Make sure all static objects required to destruct a tensor have completed
     // construction. Any object with static storage duration that holds tensors must
     // call this method before its constructor completes.
-    static void _static_init();
+    static void static_initialize();

 private:
     TensorLayout m_layout;
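Editor's note: a minimal sketch of the static-initialization-order guarantee the comment above relies on: function-local statics are destroyed in reverse order of construction, so touching each helper singleton early (as Tensor::static_initialize() now does) keeps it alive longer than any static-duration object that still holds tensors at exit. Registry, Tensor and Holder below are hypothetical stand-ins, not the real classes.

#include <cstdio>

struct Registry {                       // stands in for CompNodeSyncManager, the caches, etc.
    Registry() { std::printf("Registry constructed\n"); }
    ~Registry() { std::printf("Registry destroyed\n"); }
    static Registry& inst() { static Registry r; return r; }
    void remove() { std::printf("Registry::remove called\n"); }
};

struct Tensor {
    ~Tensor() { Registry::inst().remove(); }                // needs Registry during destruction
    static void static_initialize() { Registry::inst(); }   // force early construction
};

struct Holder {                         // a static-duration object that owns a Tensor
    Holder() { Tensor::static_initialize(); }
    Tensor t;
};
static Holder holder;   // Registry finishes constructing before holder does, so it is
                        // destroyed after holder; ~Tensor can still call into it safely.

int main() { return 0; }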
@@ -19,7 +19,7 @@ namespace proxy_graph_detail {

 SmallVector<TensorPtr>
 apply_on_physical_tensor(const OpDef& def,
-        const SmallVector<TensorPtr>& inputs);
+        SmallVector<TensorPtr> inputs);

 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs);