GitOrigin-RevId: f79155e5c3
tags/v1.4.0-rc1
| @@ -67,6 +67,29 @@ void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) { | |||||
| blob->m_storage = storage.raw_storage(); | blob->m_storage = storage.raw_storage(); | ||||
| } | } | ||||
| DeviceTensorND BlobManagerImpl::alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) { | |||||
| DeviceTensorND dev_tensor; | |||||
| if (!m_enable) { | |||||
| dev_tensor = alloc_workspace(cn, layout); | |||||
| } else { | |||||
| MGB_TRY{ dev_tensor = alloc_workspace(cn, layout); } | |||||
| MGB_CATCH(MemAllocError&, { | |||||
| mgb_log_warn("memory allocation failed for workspace; try defragmenting"); | |||||
| defrag(cn); | |||||
| dev_tensor = alloc_workspace(cn, layout); | |||||
| }); | |||||
| } | |||||
| return dev_tensor; | |||||
| }; | |||||
| DeviceTensorND BlobManagerImpl::alloc_workspace(CompNode cn, TensorLayout layout) { | |||||
| DeviceTensorStorage storage(cn); | |||||
| storage.ensure_size(layout.dtype.size(layout.total_nr_elems())); | |||||
| DeviceTensorND dev_tensor; | |||||
| dev_tensor.reset(storage, layout); | |||||
| return dev_tensor; | |||||
| } | |||||
| void BlobManagerImpl::defrag(const CompNode& cn) { | void BlobManagerImpl::defrag(const CompNode& cn) { | ||||
| BlobSetWithMux* blobs_set_ptr; | BlobSetWithMux* blobs_set_ptr; | ||||
| { | { | ||||
| @@ -136,6 +159,9 @@ struct BlobManagerStub : BlobManager { | |||||
| void alloc_with_defrag(Blob* blob, size_t size) { | void alloc_with_defrag(Blob* blob, size_t size) { | ||||
| mgb_assert(0, "prohibited after global variable destruction"); | mgb_assert(0, "prohibited after global variable destruction"); | ||||
| }; | }; | ||||
| DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) { | |||||
| mgb_assert(0, "prohibited after global variable destruction"); | |||||
| }; | |||||
| void register_blob(Blob* blob) { | void register_blob(Blob* blob) { | ||||
| mgb_assert(0, "prohibited after global variable destruction"); | mgb_assert(0, "prohibited after global variable destruction"); | ||||
| }; | }; | ||||
| @@ -45,11 +45,15 @@ class BlobManagerImpl final: public BlobManager { | |||||
| void alloc_direct(Blob* blob, size_t size); | void alloc_direct(Blob* blob, size_t size); | ||||
| DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout); | |||||
| public: | public: | ||||
| static BlobManager* inst(); | static BlobManager* inst(); | ||||
| void alloc_with_defrag(Blob* blob, size_t size) override; | void alloc_with_defrag(Blob* blob, size_t size) override; | ||||
| DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) override; | |||||
| void register_blob(Blob* blob) override; | void register_blob(Blob* blob) override; | ||||
| void unregister_blob(Blob* blob) override; | void unregister_blob(Blob* blob) override; | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "../op_trait.h" | #include "../op_trait.h" | ||||
| #include "../dnn_op_helper.h" | #include "../dnn_op_helper.h" | ||||
| #include "../blob_manager_impl.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace imperative { | namespace imperative { | ||||
| @@ -102,11 +103,16 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| const OpDef& def, | const OpDef& def, | ||||
| const SmallVector<TensorPtr>& inputs) { | const SmallVector<TensorPtr>& inputs) { | ||||
| auto&& op_def = def.cast_final_safe<Elemwise>(); | |||||
| SmallVector<DeviceTensorND> inp_tensornds(inputs.size()); | SmallVector<DeviceTensorND> inp_tensornds(inputs.size()); | ||||
| TensorShapeArray inp_shapes(inputs.size()); | |||||
| for (unsigned i = 0; i < inputs.size(); ++i){ | for (unsigned i = 0; i < inputs.size(); ++i){ | ||||
| inp_tensornds[i] = inputs[i]->dev_tensor(); | inp_tensornds[i] = inputs[i]->dev_tensor(); | ||||
| inp_shapes[i] = inputs[i]->layout(); | |||||
| } | } | ||||
| SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}}; | |||||
| TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes); | |||||
| DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(inp_tensornds[0].comp_node(), {shape, inp_tensornds[0].layout().dtype}); | |||||
| SmallVector<DeviceTensorND> oup_tensornds = {out}; | |||||
| apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds); | apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds); | ||||
| return {Tensor::make(oup_tensornds[0])}; | return {Tensor::make(oup_tensornds[0])}; | ||||
| } | } | ||||
| @@ -555,10 +555,7 @@ void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs) { | |||||
| if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | ||||
| // alloc workspace | // alloc workspace | ||||
| TensorLayout layout{var->shape(), var->dtype(), var->format()}; | TensorLayout layout{var->shape(), var->dtype(), var->format()}; | ||||
| DeviceTensorStorage storage; | |||||
| storage.comp_node(var->comp_node()) | |||||
| .ensure_size(layout.dtype.size(layout.total_nr_elems())); | |||||
| var->m_dev_tensor.reset(storage, layout); | |||||
| var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(var->comp_node(), layout); | |||||
| } else { | } else { | ||||
| mgb_assert(j < outputs.size()); | mgb_assert(j < outputs.size()); | ||||
| auto &&tensor = outputs[j]; | auto &&tensor = outputs[j]; | ||||
| @@ -24,6 +24,8 @@ public: | |||||
| virtual void alloc_with_defrag(Blob* blob, size_t size) = 0; | virtual void alloc_with_defrag(Blob* blob, size_t size) = 0; | ||||
| virtual DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) = 0; | |||||
| virtual void register_blob(Blob* blob) = 0; | virtual void register_blob(Blob* blob) = 0; | ||||
| virtual void unregister_blob(Blob* blob) = 0; | virtual void unregister_blob(Blob* blob) = 0; | ||||