Browse Source

fix(lite/cambricon): fix cambricon models which have multiple comp node

GitOrigin-RevId: 624fd7f0ce
tags/v1.7.2.m1
Megvii Engine Team XindaH 4 years ago
parent
commit
2b8e7940b6
7 changed files with 34 additions and 3 deletions
  1. +9
    -0
      imperative/src/impl/ops/magicmind_runtime.cpp
  2. +11
    -0
      lite/src/mge/network_impl.cpp
  3. +3
    -1
      src/cambricon/impl/magicmind_runtime_opr.cpp
  4. +4
    -0
      src/cambricon/impl/magicmind_runtime_opr.sereg.h
  5. +2
    -0
      src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h
  6. +2
    -0
      src/cambricon/test/magicmind_runtime_opr.cpp
  7. +3
    -2
      src/core/impl/comp_node/cpu/comp_node.cpp

+ 9
- 0
imperative/src/impl/ops/magicmind_runtime.cpp View File

@@ -20,11 +20,20 @@ namespace {
namespace magicmind_runtime {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
#if CNRT_MAJOR_VERSION >= 5
auto&& op = static_cast<const MagicMindRuntime&>(def);
SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
OperatorNodeConfig config{op.make_name()};
return opr::MagicMindRuntimeOpr::make(
op.buf.c_str(), op.buf_size, symbol_var_inputs, config);
#else
mgb_assert(
false,
"Magicmind runtime opr is disabled at compile time, the reason of which is "
"the version of cnrt runtime is lower than 5.0. Please check the version "
"of your cambricon toolkit, and recompile megengine.");
return SymbolVar{};
#endif
}
OP_TRAIT_REG(MagicMindRuntime, MagicMindRuntime)
.apply_on_var_node(apply_on_var_node)


+ 11
- 0
lite/src/mge/network_impl.cpp View File

@@ -129,6 +129,17 @@ void NetworkImplDft::application_config() {
loc.stream = m_nr_threads;
}
};
//! currently not set Locator type because a cambricon mgb model is a
//! cross-compnode graph
} else if (device_type == LiteDeviceType::LITE_CAMBRICON) {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::CAMBRICON) {
loc.device = m_compnode_locator.device;
loc.stream = m_compnode_locator.stream;
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
loc.stream = m_nr_threads;
}
};
} else {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
loc = m_compnode_locator;


+ 3
- 1
src/cambricon/impl/magicmind_runtime_opr.cpp View File

@@ -14,6 +14,7 @@
#include "megbrain/comp_node_env.h"

#if MGB_CAMBRICON
#if CNRT_MAJOR_VERSION >= 5

using namespace mgb;
using namespace opr;
@@ -168,7 +169,7 @@ MagicMindRuntimeOpr::MagicMindRuntimeOpr(
m_allocator{std::move(allocator)},
m_engine{nullptr},
m_context{nullptr},
m_model{std::move(model)},
m_model{std::move(model)},
m_current_ptr{nullptr} {
mgb_assert(
inputs[0]->comp_node().device_type() == CompNode::DeviceType::CAMBRICON,
@@ -387,6 +388,7 @@ SymbolVarArray MagicMindRuntimeOpr::make(
return make(std::move(model), std::move(cambricon_allocator), src, config);
}

#endif // CNRT_MAJOR_VERSION
#endif // MGB_CAMBRICON

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 4
- 0
src/cambricon/impl/magicmind_runtime_opr.sereg.h View File

@@ -12,6 +12,8 @@
#include "megbrain/cambricon/magicmind_runtime_opr.h"
#include "megbrain/serialization/sereg.h"

#if CNRT_MAJOR_VERSION >= 5

namespace mgb {
namespace serialization {

@@ -62,4 +64,6 @@ MGB_REG_OPR_SHALLOW_COPY(MagicMindRuntimeOpr, opr_shallow_copy_magicmind_runtime
} // namespace opr
} // namespace mgb

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 2
- 0
src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h View File

@@ -15,6 +15,7 @@
#include "megbrain/serialization/file.h"

#if MGB_CAMBRICON
#if CNRT_MAJOR_VERSION >= 5

#include <sstream>
#include "interface_runtime.h"
@@ -99,6 +100,7 @@ private:
} // namespace opr
} // namespace mgb

#endif // CNRT_MAJOR_VERSION
#endif // MGB_CAMBRICON

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 2
- 0
src/cambricon/test/magicmind_runtime_opr.cpp View File

@@ -17,6 +17,7 @@
#include "megbrain/test/helper.h"

#if MGB_CAMBRICON
#if CNRT_MAJOR_VERSION >= 5

#include "megbrain/cambricon/magicmind_runtime_opr.h"

@@ -827,6 +828,7 @@ TEST(TestMagicMindRuntimeOpr, CrossCNCopy) {
MGB_ASSERT_TENSOR_NEAR(o2, o2_mm, 1e-4);
}

#endif
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 3
- 2
src/core/impl/comp_node/cpu/comp_node.cpp View File

@@ -1097,7 +1097,8 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp
mgb_throw_if(
type != CompNode::DeviceType::CPU &&
type != CompNode::DeviceType::CUDA
&& type != CompNode::DeviceType::ATLAS
&& type != CompNode::DeviceType::ATLAS &&
type != CompNode::DeviceType::CAMBRICON
,
MegBrainError,
"currently CPU can only wait for CPU, CUDA, ATLAS"
@@ -1116,7 +1117,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp
#else
mgb_throw(
MegBrainError,
"Cambricon comp_node used but MGB_CAMBRICON not enabled");
"Cambricon comp_node used but CAMBRICON BUILD not enabled");
#endif
}



Loading…
Cancel
Save