From 2b8e7940b6fd6a1ef0ba6d43dd1d406b23a2aa2c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 10 Dec 2021 14:51:36 +0800 Subject: [PATCH] fix(lite/cambricon): fix cambricon models which have multiple comp node GitOrigin-RevId: 624fd7f0ce7cadcaabf5584b10619532a7f5a231 --- imperative/src/impl/ops/magicmind_runtime.cpp | 9 +++++++++ lite/src/mge/network_impl.cpp | 11 +++++++++++ src/cambricon/impl/magicmind_runtime_opr.cpp | 4 +++- src/cambricon/impl/magicmind_runtime_opr.sereg.h | 4 ++++ .../megbrain/cambricon/magicmind_runtime_opr.h | 2 ++ src/cambricon/test/magicmind_runtime_opr.cpp | 2 ++ src/core/impl/comp_node/cpu/comp_node.cpp | 5 +++-- 7 files changed, 34 insertions(+), 3 deletions(-) diff --git a/imperative/src/impl/ops/magicmind_runtime.cpp b/imperative/src/impl/ops/magicmind_runtime.cpp index 9bb54605..3ed295d2 100644 --- a/imperative/src/impl/ops/magicmind_runtime.cpp +++ b/imperative/src/impl/ops/magicmind_runtime.cpp @@ -20,11 +20,20 @@ namespace { namespace magicmind_runtime { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { +#if CNRT_MAJOR_VERSION >= 5 auto&& op = static_cast(def); SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end()); OperatorNodeConfig config{op.make_name()}; return opr::MagicMindRuntimeOpr::make( op.buf.c_str(), op.buf_size, symbol_var_inputs, config); +#else + mgb_assert( + false, + "Magicmind runtime opr is disabled at compile time, the reason of which is " + "the version of cnrt runtime is lower than 5.0. Please check the version " + "of your cambricon toolkit, and recompile megengine."); + return SymbolVar{}; +#endif } OP_TRAIT_REG(MagicMindRuntime, MagicMindRuntime) .apply_on_var_node(apply_on_var_node) diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index c7db5f4c..b844f561 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -129,6 +129,17 @@ void NetworkImplDft::application_config() { loc.stream = m_nr_threads; } }; + //! currently not set Locator type because a cambricon mgb model is a + //! cross-compnode graph + } else if (device_type == LiteDeviceType::LITE_CAMBRICON) { + m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { + if (loc.type == mgb::CompNode::DeviceType::CAMBRICON) { + loc.device = m_compnode_locator.device; + loc.stream = m_compnode_locator.stream; + } else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) { + loc.stream = m_nr_threads; + } + }; } else { m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { loc = m_compnode_locator; diff --git a/src/cambricon/impl/magicmind_runtime_opr.cpp b/src/cambricon/impl/magicmind_runtime_opr.cpp index 7bf5c534..569e0e7f 100644 --- a/src/cambricon/impl/magicmind_runtime_opr.cpp +++ b/src/cambricon/impl/magicmind_runtime_opr.cpp @@ -14,6 +14,7 @@ #include "megbrain/comp_node_env.h" #if MGB_CAMBRICON +#if CNRT_MAJOR_VERSION >= 5 using namespace mgb; using namespace opr; @@ -168,7 +169,7 @@ MagicMindRuntimeOpr::MagicMindRuntimeOpr( m_allocator{std::move(allocator)}, m_engine{nullptr}, m_context{nullptr}, - m_model{std::move(model)}, + m_model{std::move(model)}, m_current_ptr{nullptr} { mgb_assert( inputs[0]->comp_node().device_type() == CompNode::DeviceType::CAMBRICON, @@ -387,6 +388,7 @@ SymbolVarArray MagicMindRuntimeOpr::make( return make(std::move(model), std::move(cambricon_allocator), src, config); } +#endif // CNRT_MAJOR_VERSION #endif // MGB_CAMBRICON // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/cambricon/impl/magicmind_runtime_opr.sereg.h b/src/cambricon/impl/magicmind_runtime_opr.sereg.h index 67b0ab3d..ec2de366 100644 --- a/src/cambricon/impl/magicmind_runtime_opr.sereg.h +++ b/src/cambricon/impl/magicmind_runtime_opr.sereg.h @@ -12,6 +12,8 @@ #include "megbrain/cambricon/magicmind_runtime_opr.h" #include "megbrain/serialization/sereg.h" +#if CNRT_MAJOR_VERSION >= 5 + namespace mgb { namespace serialization { @@ -62,4 +64,6 @@ MGB_REG_OPR_SHALLOW_COPY(MagicMindRuntimeOpr, opr_shallow_copy_magicmind_runtime } // namespace opr } // namespace mgb +#endif + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h b/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h index 0fd8a318..0c80d2c4 100644 --- a/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h +++ b/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h @@ -15,6 +15,7 @@ #include "megbrain/serialization/file.h" #if MGB_CAMBRICON +#if CNRT_MAJOR_VERSION >= 5 #include #include "interface_runtime.h" @@ -99,6 +100,7 @@ private: } // namespace opr } // namespace mgb +#endif // CNRT_MAJOR_VERSION #endif // MGB_CAMBRICON // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/cambricon/test/magicmind_runtime_opr.cpp b/src/cambricon/test/magicmind_runtime_opr.cpp index 7ddeb111..38ce0d9b 100644 --- a/src/cambricon/test/magicmind_runtime_opr.cpp +++ b/src/cambricon/test/magicmind_runtime_opr.cpp @@ -17,6 +17,7 @@ #include "megbrain/test/helper.h" #if MGB_CAMBRICON +#if CNRT_MAJOR_VERSION >= 5 #include "megbrain/cambricon/magicmind_runtime_opr.h" @@ -827,6 +828,7 @@ TEST(TestMagicMindRuntimeOpr, CrossCNCopy) { MGB_ASSERT_TENSOR_NEAR(o2, o2_mm, 1e-4); } +#endif #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/core/impl/comp_node/cpu/comp_node.cpp b/src/core/impl/comp_node/cpu/comp_node.cpp index 3beab27c..5362321c 100644 --- a/src/core/impl/comp_node/cpu/comp_node.cpp +++ b/src/core/impl/comp_node/cpu/comp_node.cpp @@ -1097,7 +1097,8 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp mgb_throw_if( type != CompNode::DeviceType::CPU && type != CompNode::DeviceType::CUDA - && type != CompNode::DeviceType::ATLAS + && type != CompNode::DeviceType::ATLAS && + type != CompNode::DeviceType::CAMBRICON , MegBrainError, "currently CPU can only wait for CPU, CUDA, ATLAS" @@ -1116,7 +1117,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp #else mgb_throw( MegBrainError, - "Cambricon comp_node used but MGB_CAMBRICON not enabled"); + "Cambricon comp_node used but CAMBRICON BUILD not enabled"); #endif }