|
|
|
@@ -168,7 +168,8 @@ MagicMindRuntimeOpr::MagicMindRuntimeOpr(
           m_allocator{std::move(allocator)},
           m_engine{nullptr},
           m_context{nullptr},
-          m_model{std::move(model)} {
+          m_model{std::move(model)},
+          m_current_ptr{nullptr} {
     mgb_assert(
             inputs[0]->comp_node().device_type() == CompNode::DeviceType::CAMBRICON,
             "MagicMindRuntimeOpr can only be used on cambricon comp node; "
@@ -230,8 +231,18 @@ void MagicMindRuntimeOpr::scn_do_execute() {
         MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(output(i)->shape())));
         MM_CHECK(tensor->SetData(output(i)->dev_tensor().raw_ptr()));
     }
-    auto size = output().back()->dev_tensor().layout().span().dist_byte();
-    MM_CHECK(m_context->SetWorkspace(output().back()->dev_tensor().raw_ptr(), size));
+    if (m_current_ptr == nullptr) {
+        auto size = output().back()->dev_tensor().layout().span().dist_byte();
+        m_current_ptr = output().back()->dev_tensor().raw_ptr();
+        MM_CHECK(m_context->SetWorkspace(m_current_ptr, size));
+    } else {
+        auto current_ptr = output().back()->dev_tensor().raw_ptr();
+        mgb_assert(
+                current_ptr == m_current_ptr,
+                "workspace has been changed, the execution context should be "
+                "reconstructed, but now this is not supported (got:%p,prev:%p)",
+                current_ptr, m_current_ptr);
+    }
     MM_CHECK(m_context->Enqueue(inputs, outputs, cnrt_env.queue));
     for (auto&& i : inputs) {
         i->Destroy();
@@ -293,7 +304,6 @@ void MagicMindRuntimeOpr::get_output_var_shape(
                 false, "static shape infer for MagicMindRuntimeOpr(%s) failed",
                 cname());
     }
-    return;
     for (auto&& i : inputs) {
         i->Destroy();
     }
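
The second hunk replaces an unconditional SetWorkspace call with a set-once-then-verify scheme: the workspace pointer is bound to the execution context on the first run of scn_do_execute() and only checked on later runs, since the context cannot currently be reconstructed if the buffer moves. Below is a minimal, self-contained C++ sketch of that pattern; the Context and RuntimeOpr types here are hypothetical stand-ins, not MagicMind's IContext or the actual operator class.

#include <cassert>
#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for an engine execution context whose workspace
// can only be bound once; rebinding would require reconstructing it.
struct Context {
    void* workspace = nullptr;
    void set_workspace(void* ptr, size_t /*size*/) { workspace = ptr; }
};

struct RuntimeOpr {
    Context ctx;
    void* current_ptr = nullptr;  // workspace pointer captured on the first run

    void execute(void* workspace_ptr, size_t workspace_size) {
        if (current_ptr == nullptr) {
            // first execution: bind the workspace to the context once
            current_ptr = workspace_ptr;
            ctx.set_workspace(current_ptr, workspace_size);
        } else {
            // later executions: the context was built against the old buffer,
            // so a moved workspace would require rebuilding the context
            assert(workspace_ptr == current_ptr &&
                   "workspace changed; execution context must be reconstructed");
        }
        // ... enqueue the actual work here ...
    }
};

int main() {
    RuntimeOpr opr;
    static char buf[1024];
    opr.execute(buf, sizeof(buf));  // binds the workspace
    opr.execute(buf, sizeof(buf));  // same pointer, so the check passes
    std::puts("ok");
}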
|
|
|
|