From cfad9a5df3cc8a5829b57636bff52858b99a3f06 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 3 Dec 2021 11:50:18 +0800 Subject: [PATCH] fix(mgb/cambricon): fix magicmind runtime opr when set workspace point second time GitOrigin-RevId: 1ac9d0eabad312dcbcffacd7f35522a37884ddb3 --- src/cambricon/impl/magicmind_runtime_opr.cpp | 18 ++++++++++++++---- .../megbrain/cambricon/magicmind_runtime_opr.h | 1 + src/cambricon/test/magicmind_runtime_opr.cpp | 1 + 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/cambricon/impl/magicmind_runtime_opr.cpp b/src/cambricon/impl/magicmind_runtime_opr.cpp index c25b0b8d..7bf5c534 100644 --- a/src/cambricon/impl/magicmind_runtime_opr.cpp +++ b/src/cambricon/impl/magicmind_runtime_opr.cpp @@ -168,7 +168,8 @@ MagicMindRuntimeOpr::MagicMindRuntimeOpr( m_allocator{std::move(allocator)}, m_engine{nullptr}, m_context{nullptr}, - m_model{std::move(model)} { + m_model{std::move(model)}, + m_current_ptr{nullptr} { mgb_assert( inputs[0]->comp_node().device_type() == CompNode::DeviceType::CAMBRICON, "MagicMindRuntimeOpr can only be used on cambricon comp node; " @@ -230,8 +231,18 @@ void MagicMindRuntimeOpr::scn_do_execute() { MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(output(i)->shape()))); MM_CHECK(tensor->SetData(output(i)->dev_tensor().raw_ptr())); } - auto size = output().back()->dev_tensor().layout().span().dist_byte(); - MM_CHECK(m_context->SetWorkspace(output().back()->dev_tensor().raw_ptr(), size)); + if (m_current_ptr == nullptr) { + auto size = output().back()->dev_tensor().layout().span().dist_byte(); + m_current_ptr = output().back()->dev_tensor().raw_ptr(); + MM_CHECK(m_context->SetWorkspace(m_current_ptr, size)); + } else { + auto current_ptr = output().back()->dev_tensor().raw_ptr(); + mgb_assert( + current_ptr == m_current_ptr, + "workspace has been changed, the execution context should be " + "reconstructed, but now this is not supported (got:%p,prev:%p)", + current_ptr, m_current_ptr); + } MM_CHECK(m_context->Enqueue(inputs, outputs, cnrt_env.queue)); for (auto&& i : inputs) { i->Destroy(); @@ -293,7 +304,6 @@ void MagicMindRuntimeOpr::get_output_var_shape( false, "static shape infer for MagicMindRuntimeOpr(%s) failed", cname()); } - return; for (auto&& i : inputs) { i->Destroy(); } diff --git a/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h b/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h index b72cf871..0fd8a318 100644 --- a/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h +++ b/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h @@ -93,6 +93,7 @@ private: IEnginePtr m_engine; mutable IContextPtr m_context; IModelPtr m_model; + dt_byte* m_current_ptr; }; } // namespace opr diff --git a/src/cambricon/test/magicmind_runtime_opr.cpp b/src/cambricon/test/magicmind_runtime_opr.cpp index 4a7a4460..7ddeb111 100644 --- a/src/cambricon/test/magicmind_runtime_opr.cpp +++ b/src/cambricon/test/magicmind_runtime_opr.cpp @@ -642,6 +642,7 @@ TEST(TestMagicMindRuntimeOpr, GraphShapeMutable) { auto func = graph->compile( {make_callback_copy(out1, o1), make_callback_copy(out2, o2)}); func->execute(); + func->execute(); HostTensorND o1_mm(cn, mkshp(no, co, ho, wo), dtype::Float32()), o2_mm(cn, mkshp(no, co, ho, wo), dtype::Float32()); std::memcpy(