|
|
|
@@ -104,25 +104,21 @@ SymbolVarArray gopt::optimize_for_inference( |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
void modify_conv_policy(opr::mixin::Convolution& conv, |
|
|
|
megdnn::param::ExecutionPolicy::Strategy strategy) { |
|
|
|
void modify_conv_strategy( |
|
|
|
opr::mixin::Convolution& conv, |
|
|
|
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { |
|
|
|
auto policy = conv.execution_policy_transient(); |
|
|
|
policy.strategy = strategy; |
|
|
|
conv.set_execution_policy(policy); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
void inplace_conv_opr_profile_modifier(OperatorNodeBase& opr) { |
|
|
|
modify_conv_policy( |
|
|
|
void inplace_conv_opr_modifier( |
|
|
|
OperatorNodeBase& opr, |
|
|
|
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { |
|
|
|
modify_conv_strategy( |
|
|
|
opr.cast_final_safe<Opr>(), |
|
|
|
opr::mixin::Convolution::ExecutionPolicy::Strategy::PROFILE); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
void inplace_conv_opr_profile_cache_modifier(OperatorNodeBase& opr) { |
|
|
|
modify_conv_policy(opr.cast_final_safe<Opr>(), |
|
|
|
opr::mixin::Convolution::ExecutionPolicy::Strategy:: |
|
|
|
PROFILE_HEURISTIC); |
|
|
|
strategy); |
|
|
|
} |
|
|
|
|
|
|
|
void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv, |
|
|
|
@@ -150,12 +146,20 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr, |
|
|
|
cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \ |
|
|
|
cb(BatchConvBiasForward), |
|
|
|
|
|
|
|
void gopt::enable_opr_algo_profiling_inplace( |
|
|
|
const VarNodeArrayView& dest_vars) { |
|
|
|
#if MGB_ENABLE_FASTRUN |
|
|
|
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers = |
|
|
|
{ |
|
|
|
#define CONV(t) {opr::t::typeinfo(), &inplace_conv_opr_profile_modifier<opr::t>} |
|
|
|
void gopt::modify_opr_algo_strategy_inplace( |
|
|
|
const VarNodeArrayView& dest_vars, |
|
|
|
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { |
|
|
|
#if !MGB_ENABLE_FASTRUN |
|
|
|
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; |
|
|
|
if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) { |
|
|
|
mgb_throw(MegBrainError, "fastrun is disabled at compile time"); |
|
|
|
} |
|
|
|
#endif |
|
|
|
const ThinHashMap<Typeinfo*, std::function<void(OperatorNodeBase&)>> |
|
|
|
modifiers = { |
|
|
|
#define CONV(t) \ |
|
|
|
{opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \ |
|
|
|
std::placeholders::_1, strategy)} |
|
|
|
MGB_FOREACH_FASTRUN_OPR(CONV) |
|
|
|
#undef CONV |
|
|
|
}; |
|
|
|
@@ -171,34 +175,23 @@ void gopt::enable_opr_algo_profiling_inplace( |
|
|
|
for (auto i : dest_vars) { |
|
|
|
dep_iter.add(i); |
|
|
|
} |
|
|
|
#else |
|
|
|
mgb_throw(MegBrainError, "fastrun is disabled at compile time"); |
|
|
|
#endif |
|
|
|
} |
|
|
|
|
|
|
|
void gopt::enable_opr_use_profiling_cache_inplace( |
|
|
|
void gopt::enable_opr_algo_profiling_inplace( |
|
|
|
const VarNodeArrayView& dest_vars) { |
|
|
|
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers = |
|
|
|
{ |
|
|
|
#define CONV(t) \ |
|
|
|
{opr::t::typeinfo(), &inplace_conv_opr_profile_cache_modifier<opr::t>} |
|
|
|
MGB_FOREACH_FASTRUN_OPR(CONV) |
|
|
|
#undef CONV |
|
|
|
}; |
|
|
|
|
|
|
|
auto on_opr = [&](OperatorNodeBase* opr) { |
|
|
|
auto iter = modifiers.find(opr->dyn_typeinfo()); |
|
|
|
if (iter != modifiers.end()) { |
|
|
|
iter->second(*opr); |
|
|
|
} |
|
|
|
}; |
|
|
|
modify_opr_algo_strategy_inplace(dest_vars, |
|
|
|
opr::mixin::Convolution::ExecutionPolicy:: |
|
|
|
Strategy::PROFILE); |
|
|
|
} |
|
|
|
|
|
|
|
cg::DepOprIter dep_iter{on_opr}; |
|
|
|
for (auto i : dest_vars) { |
|
|
|
dep_iter.add(i); |
|
|
|
} |
|
|
|
void gopt::enable_opr_use_profiling_cache_inplace( |
|
|
|
const VarNodeArrayView& dest_vars) { |
|
|
|
modify_opr_algo_strategy_inplace(dest_vars, |
|
|
|
opr::mixin::Convolution::ExecutionPolicy:: |
|
|
|
Strategy::PROFILE_HEURISTIC); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void gopt::set_opr_algo_workspace_limit_inplace( |
|
|
|
const VarNodeArrayView& dest_vars, size_t workspace_limit) { |
|
|
|
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> |
|
|
|
|