|
|
|
@@ -21,6 +21,10 @@ |
|
|
|
#include "megcore_opencl.h" |
|
|
|
#endif |
|
|
|
|
|
|
|
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO |
|
|
|
#include "cpuinfo.h" |
|
|
|
#endif |
|
|
|
|
|
|
|
#include <fstream> |
|
|
|
#include <memory> |
|
|
|
#include <set> |
|
|
|
@@ -42,14 +46,7 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) { |
|
|
|
LITE_ASSERT(src_impl.m_loader, "Clone network must after the network is loaded."); |
|
|
|
m_load_result = src_impl.m_loader->load(m_load_config, true); |
|
|
|
|
|
|
|
//! flag whether the model is a cross-compnode model |
|
|
|
cross_compnode_model_detect(); |
|
|
|
|
|
|
|
//! update the IO of the network |
|
|
|
update_io(); |
|
|
|
|
|
|
|
//! replace the IO when there is device input or output |
|
|
|
compile_graph(); |
|
|
|
configure_after_loaded(); |
|
|
|
} |
|
|
|
|
|
|
|
void NetworkImplDft::application_config() { |
|
|
|
@@ -364,7 +361,7 @@ void NetworkImplDft::adapt_option_valid() { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void NetworkImplDft::global_layout_transform() { |
|
|
|
void NetworkImplDft::layout_transform_optimization() { |
|
|
|
if (m_set_layout_transform) { |
|
|
|
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map; |
|
|
|
auto output_var_array = mgb::gopt::layout_transform( |
|
|
|
@@ -382,6 +379,103 @@ void NetworkImplDft::global_layout_transform() { |
|
|
|
for (auto&& item : m_load_result.output_var_map) { |
|
|
|
item.second = out_var_map[item.second]; |
|
|
|
} |
|
|
|
} else if (m_user_config->auto_optimize_inference) { |
|
|
|
//! set model weight preprocess |
|
|
|
m_load_config.comp_graph->options().graph_opt.weight_preprocess = true; |
|
|
|
LITE_LOG( |
|
|
|
"weight_preprocess is enabled, this maybe use more memory when " |
|
|
|
"infernece."); |
|
|
|
//! get the current format and data type of the model |
|
|
|
bool is_model_nchw = true; |
|
|
|
//! whether any convolution is int8 |
|
|
|
bool is_model_int8 = false; |
|
|
|
//! whether all convolutions are float32 |
|
|
|
bool is_model_float32 = true; |
|
|
|
float conv_cnt = 0; |
|
|
|
float dimshuffle_cnt = 0; |
|
|
|
|
|
|
|
//! Update the model dtype flags from the dtype of one convolution input:
//!  - Float32 leaves both flags untouched;
//!  - QuantizedS8 / Quantized8Asymm set is_model_int8 and clear
//!    is_model_float32;
//!  - any other dtype only clears is_model_float32.
auto detect_int8_model = [&](const VarNode* input) {
    const auto dtype_enum = input->dtype().enumv();
    if (dtype_enum == megdnn::DTypeEnum::Float32) {
        //! a float32 input never changes the flags
        return;
    }
    //! every non-float32 input disqualifies the "all float32" property
    is_model_float32 = false;
    const bool is_quantized_8bit =
            dtype_enum == megdnn::DTypeEnum::QuantizedS8 ||
            dtype_enum == megdnn::DTypeEnum::Quantized8Asymm;
    if (is_quantized_8bit) {
        is_model_int8 = true;
    }
};
|
|
|
|
|
|
|
//! Visitor over the operators added to this iterator. It gathers the
//! statistics used by the layout heuristic below:
//!  - is_model_nchw is cleared when a ConvolutionForward / ConvBias operator
//!    uses a format other than NCHW;
//!  - conv_cnt counts ConvolutionForward and ConvBias operators;
//!  - dimshuffle_cnt counts Dimshuffle operators;
//!  - detect_int8_model() is fed the first input of every convolution.
cg::DepOprIter dep([&](cg::OperatorNodeBase* opr) {
    if (auto conv = opr->try_cast_final<opr::ConvolutionForward>()) {
        if (conv->param().format != megdnn::param::ConvBias::Format::NCHW) {
            is_model_nchw = false;
        }
        conv_cnt++;
        detect_int8_model(conv->input(0));
    } else if (auto conv_bias = opr->try_cast_final<opr::ConvBias>()) {
        if (conv_bias->param().format !=
            megdnn::param::ConvBias::Format::NCHW) {
            is_model_nchw = false;
        }
        conv_cnt++;
        //! read the input from conv_bias: `conv` is still in scope in this
        //! else-branch but is null here, so it must not be dereferenced
        detect_int8_model(conv_bias->input(0));
    } else if (auto dimshuffle = opr->try_cast_final<opr::Dimshuffle>()) {
        //! only the presence of the operator matters, not its content
        LITE_MARK_USED_VAR(dimshuffle);
        dimshuffle_cnt++;
    }
});
|
|
|
for (auto&& i : m_load_result.output_var_list) |
|
|
|
dep.add(i); |
|
|
|
|
|
|
|
float radio_dimshuffle_conv = 0; |
|
|
|
if (conv_cnt > 0) { |
|
|
|
radio_dimshuffle_conv = dimshuffle_cnt / conv_cnt; |
|
|
|
} |
|
|
|
//! format optimization can only be applied to an nchw model; |
|
|
|
//! shufflenet-like models perform worse when using the nchw88 or nchw44 |
|
|
|
//! format, so heuristically skip the optimization when the ratio of |
|
|
|
//! dimshuffle to convolution operators exceeds a threshold |
|
|
|
if (!is_model_nchw || radio_dimshuffle_conv > 0.15f) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
//! determine the layout by the device information |
|
|
|
//! TODO: shufflenet like model use nchw88 or nchw44 will hurt the |
|
|
|
//! performance |
|
|
|
if (m_user_config->device_type == LITE_CPU) { |
|
|
|
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO |
|
|
|
cpuinfo_initialize(); |
|
|
|
//! if all convolution and matmul data type is float32 |
|
|
|
if (is_model_float32) { |
|
|
|
//! if device is x86 |
|
|
|
//! if x86 support avx, use format nchw88 |
|
|
|
if (cpuinfo_has_x86_avx()) { |
|
|
|
m_load_config.comp_graph->options().graph_opt.enable_nchw88(); |
|
|
|
LITE_LOG("Configure model inference with nchw88 format."); |
|
|
|
} else if (cpuinfo_has_x86_sse2() && !cpuinfo_has_x86_sse3()) { |
|
|
|
//! if x86 only support sse2, use format nchw44 |
|
|
|
m_load_config.comp_graph->options().graph_opt.enable_nchw44(); |
|
|
|
LITE_LOG("Configure model inference with nchw44 format."); |
|
|
|
} else if (cpuinfo_has_arm_neon()) { |
|
|
|
//! if device is arm, use format nchw44 |
|
|
|
m_load_config.comp_graph->options().graph_opt.enable_nchw44(); |
|
|
|
LITE_LOG("Configure model inference with nchw44 format."); |
|
|
|
} |
|
|
|
} else if (is_model_int8) { |
|
|
|
//! if data type of convolution is int8 |
|
|
|
//! if device is arm and support dot, use nchw44-dot format |
|
|
|
if (cpuinfo_has_arm_neon() && cpuinfo_has_arm_neon_dot()) { |
|
|
|
m_load_config.comp_graph->options().graph_opt.enable_nchw44_dot(); |
|
|
|
LITE_LOG("Configure model inference with nchw44-dot format."); |
|
|
|
} else if (cpuinfo_has_arm_neon()) { |
|
|
|
//! if device is arm and do not support dot, use nchw44 format |
|
|
|
m_load_config.comp_graph->options().graph_opt.enable_nchw44(); |
|
|
|
LITE_LOG("Configure model inference with nchw44 format."); |
|
|
|
} |
|
|
|
} |
|
|
|
#endif |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -422,10 +516,13 @@ void NetworkImplDft::load_model( |
|
|
|
} |
|
|
|
|
|
|
|
m_load_result = m_loader->load(m_load_config, true); |
|
|
|
configure_after_loaded(); |
|
|
|
} |
|
|
|
|
|
|
|
void NetworkImplDft::configure_after_loaded() { |
|
|
|
modify_exection_policy(); |
|
|
|
|
|
|
|
global_layout_transform(); |
|
|
|
layout_transform_optimization(); |
|
|
|
|
|
|
|
//! some optimization option maybe invalid in some case, so here just |
|
|
|
//! auto determine whether some options will apply. |
|
|
|
|