more info: 16cd674c56 * change MGE_OVERRIDE_LOG_LEVEL to RUNTIME_OVERRIDE_LOG_LEVEL * use ::std::getenv not MGB_GETENV for special ENV GitOrigin-RevId: ee0f9c0f72e627c331c00100f6a21adc927081dftags/v1.4.0
| @@ -23,6 +23,7 @@ | |||||
| #ifdef __ANDROID__ | #ifdef __ANDROID__ | ||||
| #include <android/log.h> | #include <android/log.h> | ||||
| #include <sys/system_properties.h> | |||||
| #endif | #endif | ||||
| using namespace mgb; | using namespace mgb; | ||||
| @@ -32,11 +33,19 @@ LogLevel config_default_log_level() { | |||||
| auto default_level = LogLevel::ERROR; | auto default_level = LogLevel::ERROR; | ||||
| //! env to config LogLevel | //! env to config LogLevel | ||||
| //! DEBUG = 0, INFO = 1, WARN = 2, ERROR = 3, NO_LOG = 4 | //! DEBUG = 0, INFO = 1, WARN = 2, ERROR = 3, NO_LOG = 4 | ||||
| //! for example , export MGE_OVERRIDE_LOG_LEVEL=0, means set LogLevel to | |||||
| //! for example , export RUNTIME_OVERRIDE_LOG_LEVEL=0, means set LogLevel to | |||||
| //! DEBUG | //! DEBUG | ||||
| if (auto env = MGB_GETENV("MGE_OVERRIDE_LOG_LEVEL")) | |||||
| if (auto env = ::std::getenv("RUNTIME_OVERRIDE_LOG_LEVEL")) | |||||
| default_level = static_cast<LogLevel>(std::stoi(env)); | default_level = static_cast<LogLevel>(std::stoi(env)); | ||||
| #ifdef __ANDROID__ | |||||
| //! special for Android prop, attention: getprop may need permission | |||||
| char buf[PROP_VALUE_MAX]; | |||||
| if (__system_property_get("RUNTIME_OVERRIDE_LOG_LEVEL", buf) > 0) { | |||||
| default_level = static_cast<LogLevel>(atoi(buf)); | |||||
| } | |||||
| #endif | |||||
| return default_level; | return default_level; | ||||
| } | } | ||||
| @@ -155,7 +164,7 @@ void default_log_handler(LogLevel level, | |||||
| default: | default: | ||||
| android_level = ANDROID_LOG_ERROR; | android_level = ANDROID_LOG_ERROR; | ||||
| } | } | ||||
| __android_log_vprint(android_level, "megbrain", fmt, ap); | |||||
| __android_log_vprint(android_level, "runtime", fmt, ap); | |||||
| #endif | #endif | ||||
| #undef HDR_FMT | #undef HDR_FMT | ||||
| @@ -185,7 +194,7 @@ class MegDNNLogHandler { | |||||
| return; | return; | ||||
| } | } | ||||
| std::string new_fmt{"[megdnn] "}; | |||||
| std::string new_fmt{"[dnn] "}; | |||||
| new_fmt.append(fmt); | new_fmt.append(fmt); | ||||
| log_handler(mgb_level, file, func, line, new_fmt.c_str(), ap); | log_handler(mgb_level, file, func, line, new_fmt.c_str(), ap); | ||||
| } | } | ||||
| @@ -238,9 +247,17 @@ namespace { | |||||
| #endif // MGB_ENABLE_LOGGING | #endif // MGB_ENABLE_LOGGING | ||||
| LogLevel mgb::set_log_level(LogLevel level) { | LogLevel mgb::set_log_level(LogLevel level) { | ||||
| if (auto env = MGB_GETENV("MGE_OVERRIDE_LOG_LEVEL")) | |||||
| if (auto env = ::std::getenv("RUNTIME_OVERRIDE_LOG_LEVEL")) | |||||
| level = static_cast<LogLevel>(std::stoi(env)); | level = static_cast<LogLevel>(std::stoi(env)); | ||||
| #ifdef __ANDROID__ | |||||
| //! special for Android prop, attention: getprop may need permission | |||||
| char buf[PROP_VALUE_MAX]; | |||||
| if (__system_property_get("RUNTIME_OVERRIDE_LOG_LEVEL", buf) > 0) { | |||||
| level = static_cast<LogLevel>(atoi(buf)); | |||||
| } | |||||
| #endif | |||||
| auto ret = min_log_level; | auto ret = min_log_level; | ||||
| min_log_level = level; | min_log_level = level; | ||||
| return ret; | return ret; | ||||
| @@ -256,7 +273,6 @@ LogHandler mgb::set_log_handler(LogHandler handler) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| #if MGB_ASSERT_LOC | |||||
| void mgb::__assert_fail__( | void mgb::__assert_fail__( | ||||
| const char *file, int line, const char *func, | const char *file, int line, const char *func, | ||||
| const char *expr, const char *msg_fmt, ...) { | const char *expr, const char *msg_fmt, ...) { | ||||
| @@ -273,11 +289,6 @@ void mgb::__assert_fail__( | |||||
| } | } | ||||
| mgb_throw_raw(AssertionError{msg}); | mgb_throw_raw(AssertionError{msg}); | ||||
| } | } | ||||
| #else | |||||
| void mgb::__assert_fail__() { | |||||
| mgb_throw(AssertionError, "assertion failed"); | |||||
| } | |||||
| #endif | |||||
| #if MGB_ENABLE_LOGGING && !MGB_ENABLE_EXCEPTION | #if MGB_ENABLE_LOGGING && !MGB_ENABLE_EXCEPTION | ||||
| void mgb::__on_exception_throw__(const std::exception &exc) { | void mgb::__on_exception_throw__(const std::exception &exc) { | ||||
| @@ -759,7 +759,7 @@ public: | |||||
| #else | #else | ||||
| mgb_throw(MegBrainError, | mgb_throw(MegBrainError, | ||||
| "Atlas comp_node used but " | "Atlas comp_node used but " | ||||
| "MGB_ATLAS not enabled"); | |||||
| "ATLAS BUILD not enabled"); | |||||
| #endif | #endif | ||||
| } else if (dest_impl->env().property().type == | } else if (dest_impl->env().property().type == | ||||
| DeviceType::CAMBRICON) { | DeviceType::CAMBRICON) { | ||||
| @@ -769,7 +769,7 @@ public: | |||||
| #else | #else | ||||
| mgb_throw(MegBrainError, | mgb_throw(MegBrainError, | ||||
| "Cambricon comp_node used but " | "Cambricon comp_node used but " | ||||
| "MGB_CAMBRICON not enabled"); | |||||
| "CAMBRICON BUILD not enabled"); | |||||
| #endif | #endif | ||||
| } | } | ||||
| else { | else { | ||||
| @@ -1035,7 +1035,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by( | |||||
| return m_comp_node_impl->sync(); | return m_comp_node_impl->sync(); | ||||
| #else | #else | ||||
| mgb_throw(MegBrainError, | mgb_throw(MegBrainError, | ||||
| "Atlas comp_node used but MGB_ATLAS not enabled"); | |||||
| "Atlas comp_node used but ATLAS BUILD not enabled"); | |||||
| #endif | #endif | ||||
| } else if (cn_impl->env().property().type == | } else if (cn_impl->env().property().type == | ||||
| CompNode::DeviceType::CAMBRICON) { | CompNode::DeviceType::CAMBRICON) { | ||||
| @@ -51,14 +51,13 @@ MegDNNHandle& MegDNNHandle::get(const CompNodeEnv& env) { | |||||
| MegDNNHandle::MegDNNHandle(const CompNodeEnv& env) { | MegDNNHandle::MegDNNHandle(const CompNodeEnv& env) { | ||||
| auto megdnn_version = megdnn::get_version(); | auto megdnn_version = megdnn::get_version(); | ||||
| mgb_throw_if( | |||||
| megdnn_version.major != MEGDNN_MAJOR || | |||||
| megdnn_version.minor < MEGDNN_MINOR, | |||||
| SystemError, | |||||
| "incompatible megdnn version: compiled with %d.%d, get %d.%d.%d " | |||||
| "at runtime", | |||||
| MEGDNN_MAJOR, MEGDNN_MINOR, megdnn_version.major, | |||||
| megdnn_version.minor, megdnn_version.patch); | |||||
| mgb_throw_if(megdnn_version.major != MEGDNN_MAJOR || | |||||
| megdnn_version.minor < MEGDNN_MINOR, | |||||
| SystemError, | |||||
| "incompatible dnn version: compiled with %d.%d, get %d.%d.%d " | |||||
| "at runtime", | |||||
| MEGDNN_MAJOR, MEGDNN_MINOR, megdnn_version.major, | |||||
| megdnn_version.minor, megdnn_version.patch); | |||||
| bool init = false; | bool init = false; | ||||
| #if MGB_CUDA | #if MGB_CUDA | ||||
| if (env.property().type == CompNode::DeviceType::CUDA) { | if (env.property().type == CompNode::DeviceType::CUDA) { | ||||
| @@ -880,7 +880,7 @@ std::string ComputingGraphImpl::get_mem_allocation_info() const { | |||||
| return objlist->to_string(); | return objlist->to_string(); | ||||
| #endif // MGB_ENABLE_JSON | #endif // MGB_ENABLE_JSON | ||||
| mgb_log_warn("mgb is not configured with MGB_ENABLE_JSON on," | |||||
| mgb_log_warn("target is not configured with JSON BUILD on," | |||||
| "get_mem_allocation_info returns null string"); | "get_mem_allocation_info returns null string"); | ||||
| return std::string(); | return std::string(); | ||||
| } | } | ||||
| @@ -619,7 +619,7 @@ void ComputingGraphImpl::MegDNNDtorCheck::enable() { | |||||
| mgb_assert(!m_enabled); | mgb_assert(!m_enabled); | ||||
| m_enabled = true; | m_enabled = true; | ||||
| auto cb_dnn = [](megdnn::OperatorBase* opr) { | auto cb_dnn = [](megdnn::OperatorBase* opr) { | ||||
| mgb_log_error("unexpected destruction of megdnn opr %p", opr); | |||||
| mgb_log_error("unexpected destruction of dnn opr %p", opr); | |||||
| mgb_trap(); | mgb_trap(); | ||||
| }; | }; | ||||
| auto cb_mem = [](size_t alloc_size, bool, void* ptr) { | auto cb_mem = [](size_t alloc_size, bool, void* ptr) { | ||||
| @@ -108,34 +108,33 @@ void __on_exception_throw__(const std::exception &exc) | |||||
| } while(0) | } while(0) | ||||
| // assert | // assert | ||||
| void __assert_fail__(const char* file, int line, const char* func, | |||||
| const char* expr, const char* msg_fmt = 0, ...) | |||||
| __attribute__((format(printf, 5, 6), noreturn)); | |||||
| #if MGB_ASSERT_LOC | #if MGB_ASSERT_LOC | ||||
| /*! | /*! | ||||
| * \brief extended assert | * \brief extended assert | ||||
| * extra diagnostics message (in printf format) could be printed when assertion | * extra diagnostics message (in printf format) could be printed when assertion | ||||
| * fails; the asserted expression is guaranteed to be evaluated | * fails; the asserted expression is guaranteed to be evaluated | ||||
| */ | */ | ||||
| #define mgb_assert(expr, msg...) \ | |||||
| do { \ | |||||
| if (mgb_unlikely(!(expr))) \ | |||||
| ::mgb::__assert_fail__(__FILE__, __LINE__, \ | |||||
| __PRETTY_FUNCTION__, # expr, ##msg); \ | |||||
| } while(0) | |||||
| void __assert_fail__( | |||||
| const char *file, int line, const char *func, | |||||
| const char *expr, const char *msg_fmt = 0, ...) | |||||
| __attribute__((format(printf, 5, 6), noreturn)); | |||||
| #define mgb_assert(expr, msg...) \ | |||||
| do { \ | |||||
| if (mgb_unlikely(!(expr))) \ | |||||
| ::mgb::__assert_fail__(__FILE__, __LINE__, __PRETTY_FUNCTION__, \ | |||||
| #expr, ##msg); \ | |||||
| } while (0) | |||||
| #else | #else | ||||
| #define mgb_assert(expr, msg...) \ | |||||
| do { \ | |||||
| if (mgb_unlikely(!(expr))) \ | |||||
| ::mgb::__assert_fail__(); \ | |||||
| } while(0) | |||||
| void __assert_fail__() __attribute__((noreturn)); | |||||
| #endif // MGB_ASSERT_LOC | |||||
| #define mgb_assert(expr, msg...) \ | |||||
| do { \ | |||||
| if (mgb_unlikely(!(expr))) \ | |||||
| ::mgb::__assert_fail__( \ | |||||
| "about location info, please build with debug", __LINE__, \ | |||||
| NULL, #expr, ##msg); \ | |||||
| } while (0) | |||||
| #endif // MGB_ASSERT_LOC | |||||
| /* ================ logging ================ */ | /* ================ logging ================ */ | ||||
| //! caused by need remve some words at opt release | |||||
| #if MGB_ENABLE_LOGGING | |||||
| #if MGB_ASSERT_LOC | |||||
| #define mgb_log_debug(fmt...) \ | #define mgb_log_debug(fmt...) \ | ||||
| _mgb_do_log(::mgb::LogLevel::DEBUG, __FILE__, __func__, __LINE__, fmt) | _mgb_do_log(::mgb::LogLevel::DEBUG, __FILE__, __func__, __LINE__, fmt) | ||||
| #define mgb_log(fmt...) \ | #define mgb_log(fmt...) \ | ||||
| @@ -154,7 +153,6 @@ void __assert_fail__() __attribute__((noreturn)); | |||||
| _mgb_do_log(::mgb::LogLevel::WARN, "", "", __LINE__, fmt) | _mgb_do_log(::mgb::LogLevel::WARN, "", "", __LINE__, fmt) | ||||
| #define mgb_log_error(fmt...) \ | #define mgb_log_error(fmt...) \ | ||||
| _mgb_do_log(::mgb::LogLevel::ERROR, LOC, "", __LINE__, fmt) | _mgb_do_log(::mgb::LogLevel::ERROR, LOC, "", __LINE__, fmt) | ||||
| #undef LOC | |||||
| #endif | #endif | ||||
| enum class LogLevel { DEBUG, INFO, WARN, ERROR, NO_LOG }; | enum class LogLevel { DEBUG, INFO, WARN, ERROR, NO_LOG }; | ||||
| @@ -1045,7 +1045,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| INTER_WEIGHT_DENSEI_DOT; | INTER_WEIGHT_DENSEI_DOT; | ||||
| return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI; | return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_DENSEI; | ||||
| } else { | } else { | ||||
| mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); | |||||
| mgb_throw_if(conv_mode != megdnn::param::Convolution::Sparse::GROUP, | |||||
| MegBrainError, "mode error"); | |||||
| if (filter->shape()[1] == 1 && filter->shape()[2] == 1) { | if (filter->shape()[1] == 1 && filter->shape()[2] == 1) { | ||||
| return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI; | return megdnn::param::RelayoutFormat::Mode::INTER_WEIGHT_CHANI; | ||||
| } else { | } else { | ||||
| @@ -1081,9 +1082,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| const VarNodeArray& new_inp) { | const VarNodeArray& new_inp) { | ||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | ||||
| mgb_assert(conv_opr.param().format == | |||||
| megdnn::param::Convolution::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| mgb_throw_if( | |||||
| conv_opr.param().format != | |||||
| megdnn::param::Convolution::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| VarNode *conv_src = nullptr, *conv_weights = nullptr; | VarNode *conv_src = nullptr, *conv_weights = nullptr; | ||||
| if (new_inp[0]->shape().ndim == 4) { | if (new_inp[0]->shape().ndim == 4) { | ||||
| // new input src is NCHW | // new input src is NCHW | ||||
| @@ -1094,8 +1097,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| icpg = new_inp[1]->shape()[1]; | icpg = new_inp[1]->shape()[1]; | ||||
| ocpg = new_inp[1]->shape()[0]; | ocpg = new_inp[1]->shape()[0]; | ||||
| } else { | } else { | ||||
| mgb_assert(conv_opr.param().sparse == | |||||
| megdnn::param::Convolution::Sparse::GROUP); | |||||
| mgb_throw_if(conv_opr.param().sparse != | |||||
| megdnn::param::Convolution::Sparse::GROUP, | |||||
| MegBrainError, "ERROR mode"); | |||||
| group = new_inp[1]->shape()[0]; | group = new_inp[1]->shape()[0]; | ||||
| icpg = new_inp[1]->shape()[2]; | icpg = new_inp[1]->shape()[2]; | ||||
| ocpg = new_inp[1]->shape()[1]; | ocpg = new_inp[1]->shape()[1]; | ||||
| @@ -1117,8 +1121,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| megdnn::param::Convolution::Sparse::DENSE) { | megdnn::param::Convolution::Sparse::DENSE) { | ||||
| ocpg = new_inp[1]->shape()[0]; | ocpg = new_inp[1]->shape()[0]; | ||||
| } else { | } else { | ||||
| mgb_assert(conv_opr.param().sparse == | |||||
| megdnn::param::Convolution::Sparse::GROUP); | |||||
| mgb_throw_if(conv_opr.param().sparse != | |||||
| megdnn::param::Convolution::Sparse::GROUP, | |||||
| MegBrainError, "ERROR mode"); | |||||
| size_t icpg = new_inp[1]->shape()[2]; | size_t icpg = new_inp[1]->shape()[2]; | ||||
| ocpg = new_inp[1]->shape()[1]; | ocpg = new_inp[1]->shape()[1]; | ||||
| if (icpg == 1 && ocpg == 1) { | if (icpg == 1 && ocpg == 1) { | ||||
| @@ -1176,9 +1181,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| const VarNodeArray& new_inp) { | const VarNodeArray& new_inp) { | ||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | ||||
| mgb_assert(conv_bias_opr.param().format == | |||||
| megdnn::param::ConvBias::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| mgb_throw_if( | |||||
| conv_bias_opr.param().format != | |||||
| megdnn::param::ConvBias::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr, | VarNode *conv_bias_src = nullptr, *conv_bias_weights = nullptr, | ||||
| *conv_bias_bias = nullptr; | *conv_bias_bias = nullptr; | ||||
| if (new_inp[0]->shape().ndim == 4) { | if (new_inp[0]->shape().ndim == 4) { | ||||
| @@ -1190,8 +1197,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| icpg = new_inp[1]->shape()[1]; | icpg = new_inp[1]->shape()[1]; | ||||
| ocpg = new_inp[1]->shape()[0]; | ocpg = new_inp[1]->shape()[0]; | ||||
| } else { | } else { | ||||
| mgb_assert(conv_bias_opr.param().sparse == | |||||
| megdnn::param::ConvBias::Sparse::GROUP); | |||||
| mgb_throw_if(conv_bias_opr.param().sparse != | |||||
| megdnn::param::ConvBias::Sparse::GROUP, | |||||
| MegBrainError, "mode error"); | |||||
| group = new_inp[1]->shape()[0]; | group = new_inp[1]->shape()[0]; | ||||
| icpg = new_inp[1]->shape()[2]; | icpg = new_inp[1]->shape()[2]; | ||||
| ocpg = new_inp[1]->shape()[1]; | ocpg = new_inp[1]->shape()[1]; | ||||
| @@ -1213,8 +1221,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| megdnn::param::ConvBias::Sparse::DENSE) { | megdnn::param::ConvBias::Sparse::DENSE) { | ||||
| ocpg = new_inp[1]->shape()[0]; | ocpg = new_inp[1]->shape()[0]; | ||||
| } else { | } else { | ||||
| mgb_assert(conv_bias_opr.param().sparse == | |||||
| megdnn::param::ConvBias::Sparse::GROUP); | |||||
| mgb_throw_if(conv_bias_opr.param().sparse != | |||||
| megdnn::param::ConvBias::Sparse::GROUP, | |||||
| MegBrainError, "ERROR mode"); | |||||
| size_t icpg = new_inp[1]->shape()[2]; | size_t icpg = new_inp[1]->shape()[2]; | ||||
| ocpg = new_inp[1]->shape()[1]; | ocpg = new_inp[1]->shape()[1]; | ||||
| if (icpg == 1 && ocpg == 1) { | if (icpg == 1 && ocpg == 1) { | ||||
| @@ -1293,9 +1302,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| const VarNodeArray& new_inp) { | const VarNodeArray& new_inp) { | ||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>(); | auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>(); | ||||
| mgb_assert(deconv_opr.param().format == | |||||
| megdnn::param::Convolution::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| mgb_throw_if( | |||||
| deconv_opr.param().format != | |||||
| megdnn::param::Convolution::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| VarNode *deconv_src = nullptr, *deconv_weights = nullptr; | VarNode *deconv_src = nullptr, *deconv_weights = nullptr; | ||||
| if (new_inp[1]->shape().ndim == 4) { | if (new_inp[1]->shape().ndim == 4) { | ||||
| // new input src is NCHW | // new input src is NCHW | ||||
| @@ -1306,8 +1317,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| icpg = new_inp[0]->shape()[0]; | icpg = new_inp[0]->shape()[0]; | ||||
| ocpg = new_inp[0]->shape()[1]; | ocpg = new_inp[0]->shape()[1]; | ||||
| } else { | } else { | ||||
| mgb_assert(deconv_opr.param().sparse == | |||||
| megdnn::param::Convolution::Sparse::GROUP); | |||||
| mgb_throw_if(deconv_opr.param().sparse != | |||||
| megdnn::param::Convolution::Sparse::GROUP, | |||||
| MegBrainError, "mode error"); | |||||
| group = new_inp[0]->shape()[0]; | group = new_inp[0]->shape()[0]; | ||||
| icpg = new_inp[0]->shape()[1]; | icpg = new_inp[0]->shape()[1]; | ||||
| ocpg = new_inp[0]->shape()[2]; | ocpg = new_inp[0]->shape()[2]; | ||||
| @@ -1329,8 +1341,9 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| megdnn::param::Convolution::Sparse::DENSE) { | megdnn::param::Convolution::Sparse::DENSE) { | ||||
| ocpg = new_inp[0]->shape()[1]; | ocpg = new_inp[0]->shape()[1]; | ||||
| } else { | } else { | ||||
| mgb_assert(deconv_opr.param().sparse == | |||||
| megdnn::param::Convolution::Sparse::GROUP); | |||||
| mgb_throw_if(deconv_opr.param().sparse != | |||||
| megdnn::param::Convolution::Sparse::GROUP, | |||||
| MegBrainError, "mode error"); | |||||
| ocpg = new_inp[0]->shape()[2]; | ocpg = new_inp[0]->shape()[2]; | ||||
| } | } | ||||
| @@ -1393,9 +1406,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| return opr_shallow_copy; | return opr_shallow_copy; | ||||
| } | } | ||||
| auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>(); | auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>(); | ||||
| mgb_assert(resize_opr.param().format == | |||||
| megdnn::param::Resize::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| mgb_throw_if( | |||||
| resize_opr.param().format != | |||||
| megdnn::param::Resize::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| VarNode* inp = nullptr; | VarNode* inp = nullptr; | ||||
| if (new_inp[0]->shape().ndim == 4) { | if (new_inp[0]->shape().ndim == 4) { | ||||
| auto param = megdnn::param::RelayoutFormat(); | auto param = megdnn::param::RelayoutFormat(); | ||||
| @@ -1425,9 +1440,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| return opr_shallow_copy; | return opr_shallow_copy; | ||||
| } | } | ||||
| auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>(); | auto& warp_opr = opr->cast_final_safe<opr::WarpPerspectiveForward>(); | ||||
| mgb_assert(warp_opr.param().format == | |||||
| megdnn::param::WarpPerspective::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| mgb_throw_if( | |||||
| warp_opr.param().format != | |||||
| megdnn::param::WarpPerspective::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| VarNode* inp = nullptr; | VarNode* inp = nullptr; | ||||
| if (new_inp[0]->shape().ndim == 4) { | if (new_inp[0]->shape().ndim == 4) { | ||||
| // new input src is NCHW | // new input src is NCHW | ||||
| @@ -1466,9 +1483,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| return opr_shallow_copy; | return opr_shallow_copy; | ||||
| } | } | ||||
| auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>(); | auto& warp_opr = opr->cast_final_safe<opr::WarpAffineForward>(); | ||||
| mgb_assert(warp_opr.param().format == | |||||
| megdnn::param::WarpAffine::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| mgb_throw_if( | |||||
| warp_opr.param().format != | |||||
| megdnn::param::WarpAffine::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| VarNode* inp = nullptr; | VarNode* inp = nullptr; | ||||
| if (new_inp[0]->shape().ndim == 4) { | if (new_inp[0]->shape().ndim == 4) { | ||||
| // new input src is NCHW | // new input src is NCHW | ||||
| @@ -1499,9 +1518,11 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
| return opr_shallow_copy; | return opr_shallow_copy; | ||||
| } | } | ||||
| auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>(); | auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>(); | ||||
| mgb_assert(pooling_opr.param().format == | |||||
| megdnn::param::Pooling::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| mgb_throw_if( | |||||
| pooling_opr.param().format != | |||||
| megdnn::param::Pooling::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NHWCD4"); | |||||
| VarNode* inp = nullptr; | VarNode* inp = nullptr; | ||||
| if (new_inp[0]->shape().ndim == 4) { | if (new_inp[0]->shape().ndim == 4) { | ||||
| // new input src is NCHW | // new input src is NCHW | ||||
| @@ -1465,7 +1465,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||||
| return {weight_to_nchw4_mode_dense, src_to_nchw4_mode}; | return {weight_to_nchw4_mode_dense, src_to_nchw4_mode}; | ||||
| } | } | ||||
| } else { | } else { | ||||
| mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); | |||||
| mgb_throw_if(conv_mode != megdnn::param::Convolution::Sparse::GROUP, | |||||
| MegBrainError, "mode error"); | |||||
| mgb_assert(filter->shape().ndim == 5, | mgb_assert(filter->shape().ndim == 5, | ||||
| "The origin filter if not NCHW mode"); | "The origin filter if not NCHW mode"); | ||||
| size_t IC = filter->shape()[2]; | size_t IC = filter->shape()[2]; | ||||
| @@ -2018,7 +2019,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
| ret.second = hybrid_nchw_nchwxx; | ret.second = hybrid_nchw_nchwxx; | ||||
| } | } | ||||
| } else { | } else { | ||||
| mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); | |||||
| mgb_throw_if(conv_mode != megdnn::param::Convolution::Sparse::GROUP, | |||||
| MegBrainError, "mode error"); | |||||
| size_t group = filter->shape()[0]; | size_t group = filter->shape()[0]; | ||||
| size_t ocpg = filter->shape()[1]; | size_t ocpg = filter->shape()[1]; | ||||
| size_t icpg = filter->shape()[2]; | size_t icpg = filter->shape()[2]; | ||||
| @@ -2038,9 +2040,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
| const VarNodeArray& new_inp) { | const VarNodeArray& new_inp) { | ||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | ||||
| mgb_assert(conv_opr.param().format == | |||||
| megdnn::param::Convolution::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||||
| mgb_throw_if( | |||||
| conv_opr.param().format != | |||||
| megdnn::param::Convolution::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||||
| bool valid_nchw_nchw44 = | bool valid_nchw_nchw44 = | ||||
| nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size); | nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size); | ||||
| auto is_trans = test_trans_nchwxx( | auto is_trans = test_trans_nchwxx( | ||||
| @@ -2118,9 +2122,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
| mgb_assert(opr->input().size() <= 3, | mgb_assert(opr->input().size() <= 3, | ||||
| "nchwxx does not support conv_bias fuse Z right now"); | "nchwxx does not support conv_bias fuse Z right now"); | ||||
| auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | ||||
| mgb_assert(conv_bias_opr.param().format == | |||||
| megdnn::param::ConvBias::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||||
| mgb_throw_if( | |||||
| conv_bias_opr.param().format != | |||||
| megdnn::param::ConvBias::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||||
| bool valid_nchw_nchw44 = | bool valid_nchw_nchw44 = | ||||
| nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, | nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, | ||||
| conv_bias_opr.param().nonlineMode); | conv_bias_opr.param().nonlineMode); | ||||
| @@ -2244,9 +2250,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
| const VarNodeArray& new_inp) { | const VarNodeArray& new_inp) { | ||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>(); | auto& pooling_opr = opr->cast_final_safe<opr::PoolingForward>(); | ||||
| mgb_assert(pooling_opr.param().format == | |||||
| megdnn::param::Pooling::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHWxx"); | |||||
| mgb_throw_if( | |||||
| pooling_opr.param().format != | |||||
| megdnn::param::Pooling::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHWxx"); | |||||
| VarNode* inp = new_inp[0]; | VarNode* inp = new_inp[0]; | ||||
| //! if input is nchwxx | //! if input is nchwxx | ||||
| if (inp->shape().ndim == 5) { | if (inp->shape().ndim == 5) { | ||||
| @@ -2433,7 +2441,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); | |||||
| mgb_throw_if(conv_mode != megdnn::param::Convolution::Sparse::GROUP, | |||||
| MegBrainError, "mode error"); | |||||
| size_t group = filter->shape()[0]; | size_t group = filter->shape()[0]; | ||||
| size_t ocpg = filter->shape()[1]; | size_t ocpg = filter->shape()[1]; | ||||
| size_t icpg = filter->shape()[2]; | size_t icpg = filter->shape()[2]; | ||||
| @@ -2462,10 +2471,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||||
| const VarNodeArray& new_inp) { | const VarNodeArray& new_inp) { | ||||
| mgb_assert(opr->input().size() == new_inp.size()); | mgb_assert(opr->input().size() == new_inp.size()); | ||||
| auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | ||||
| mgb_assert(conv_opr.param().format == | |||||
| megdnn::param::Convolution::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to " | |||||
| "NCHW44_DOT"); | |||||
| mgb_throw_if(conv_opr.param().format != | |||||
| megdnn::param::Convolution::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to " | |||||
| "NCHW44_DOT"); | |||||
| bool valid_nchw_nchw44 = nchw_nchwxx_valid( | bool valid_nchw_nchw44 = nchw_nchwxx_valid( | ||||
| conv_opr, new_inp, pack_c_size, | conv_opr, new_inp, pack_c_size, | ||||
| megdnn::param::ConvBias::NonlineMode::IDENTITY, true); | megdnn::param::ConvBias::NonlineMode::IDENTITY, true); | ||||
| @@ -2543,9 +2553,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||||
| mgb_assert(opr->input().size() <= 3, | mgb_assert(opr->input().size() <= 3, | ||||
| "nchwxx-dot does not support conv_bias fuse Z right now"); | "nchwxx-dot does not support conv_bias fuse Z right now"); | ||||
| auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | ||||
| mgb_assert(conv_bias_opr.param().format == | |||||
| megdnn::param::ConvBias::Format::NCHW, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||||
| mgb_throw_if( | |||||
| conv_bias_opr.param().format != | |||||
| megdnn::param::ConvBias::Format::NCHW, | |||||
| MegBrainError, | |||||
| "ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||||
| bool valid_nchw_nchw44 = | bool valid_nchw_nchw44 = | ||||
| nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, | nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, | ||||
| conv_bias_opr.param().nonlineMode, true); | conv_bias_opr.param().nonlineMode, true); | ||||
| @@ -127,7 +127,7 @@ | |||||
| // whether to enbale configuing megbrain internals through env vars | // whether to enbale configuing megbrain internals through env vars | ||||
| #ifndef MGB_ENABLE_GETENV | #ifndef MGB_ENABLE_GETENV | ||||
| #define MGB_ENABLE_GETENV 1 | |||||
| #define MGB_ENABLE_GETENV MGB_ASSERT_LOC | |||||
| #endif | #endif | ||||
| // whether to remove unnecessary features when used for serving | // whether to remove unnecessary features when used for serving | ||||
| @@ -343,24 +343,24 @@ void Elemwise::mem_plan_fwd_in2out_writable() { | |||||
| } | } | ||||
| void Elemwise::scn_do_execute() { | void Elemwise::scn_do_execute() { | ||||
| auto &&inp = input(); | |||||
| megdnn::TensorNDArray megdnn_inp; | |||||
| mgb_assert(megdnn_inp.capacity() >= inp.size(), | |||||
| "heap allocation in elemwise exec"); | |||||
| megdnn_inp.resize(inp.size()); | |||||
| for (size_t i = 0; i < inp.size(); ++ i) { | |||||
| auto&& inp = input(); | |||||
| megdnn::TensorNDArray dnn_inp; | |||||
| mgb_assert(dnn_inp.capacity() >= inp.size(), | |||||
| "heap allocation in elemwise exec"); | |||||
| dnn_inp.resize(inp.size()); | |||||
| for (size_t i = 0; i < inp.size(); ++i) { | |||||
| if (inp[i]->dev_tensor().empty()) { | if (inp[i]->dev_tensor().empty()) { | ||||
| mgb_assert(output(0)->dev_tensor().empty()); | mgb_assert(output(0)->dev_tensor().empty()); | ||||
| return; | return; | ||||
| } | } | ||||
| megdnn_inp[i] = (inp[i]->dev_tensor().as_megdnn()); | |||||
| dnn_inp[i] = (inp[i]->dev_tensor().as_megdnn()); | |||||
| } | } | ||||
| mgb_assert(!output(0)->dev_tensor().empty()); | mgb_assert(!output(0)->dev_tensor().empty()); | ||||
| megdnn_opr()->param() = param(); | megdnn_opr()->param() = param(); | ||||
| call_megdnn_opr_exec( | |||||
| comp_node(), megdnn_inp, output(0)->dev_tensor().as_megdnn(), | |||||
| megdnn_opr(), this); | |||||
| call_megdnn_opr_exec(comp_node(), dnn_inp, | |||||
| output(0)->dev_tensor().as_megdnn(), megdnn_opr(), | |||||
| this); | |||||
| } | } | ||||
| void Elemwise::init_output_static_infer_desc() { | void Elemwise::init_output_static_infer_desc() { | ||||
| @@ -126,10 +126,11 @@ namespace serialization { | |||||
| MGB_MARK_USED_VAR(graph); | MGB_MARK_USED_VAR(graph); | ||||
| SymbolVar target_shape; | SymbolVar target_shape; | ||||
| if (inputs.size() == 1) { | if (inputs.size() == 1) { | ||||
| mgb_assert(param.axis >= | |||||
| -megdnn::param::OptionalAxisV1::MAX_NDIM && | |||||
| param.axis < | |||||
| megdnn::param::OptionalAxisV1::MAX_NDIM); | |||||
| mgb_throw_if( | |||||
| param.axis < -megdnn::param::OptionalAxisV1::MAX_NDIM || | |||||
| param.axis >= | |||||
| megdnn::param::OptionalAxisV1::MAX_NDIM, | |||||
| MegBrainError, "DIM error"); | |||||
| } else { | } else { | ||||
| mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
| target_shape = inputs[1]; | target_shape = inputs[1]; | ||||
| @@ -470,9 +470,9 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(SVD); | |||||
| SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) : | SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) : | ||||
| Super(OperatorNodeBaseCtorParam{src->owner_graph(), | Super(OperatorNodeBaseCtorParam{src->owner_graph(), | ||||
| config, "svd", {src}}) { | config, "svd", {src}}) { | ||||
| mgb_assert(src->dtype() == megdnn::dtype::Float32(), | |||||
| "Singular Value Decomposition on non-float32 tensors is " | |||||
| "not supoorted."); | |||||
| mgb_throw_if(src->dtype() != megdnn::dtype::Float32(), MegDNNError, | |||||
| "Singular Value Decomposition on non-float32 tensors is not " | |||||
| "supoorted."); | |||||
| init_megdnn_opr(*this, param); | init_megdnn_opr(*this, param); | ||||
| add_input({src}); | add_input({src}); | ||||
| @@ -187,12 +187,12 @@ template<class Opr> | |||||
| Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr( | Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr( | ||||
| cg::SingleCNOperatorNodeBase& self) { | cg::SingleCNOperatorNodeBase& self) { | ||||
| auto comp_node = self.comp_node(); | auto comp_node = self.comp_node(); | ||||
| if (!m_megdnn_opr || m_megdnn_opr.comp_node() != comp_node) { | |||||
| m_megdnn_opr = intl::create_megdnn_opr<Opr>(comp_node); | |||||
| m_megdnn_opr->set_error_tracker( | |||||
| if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node) { | |||||
| m_dnn_opr = intl::create_megdnn_opr<Opr>(comp_node); | |||||
| m_dnn_opr->set_error_tracker( | |||||
| static_cast<cg::OperatorNodeBase*>(&self)); | static_cast<cg::OperatorNodeBase*>(&self)); | ||||
| } | } | ||||
| return *m_megdnn_opr; | |||||
| return *m_dnn_opr; | |||||
| } | } | ||||
| template<class Opr> | template<class Opr> | ||||
| @@ -228,7 +228,7 @@ template <class Opr> | |||||
| void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr( | void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr( | ||||
| mgb::cg::GraphExecutable::ExecDependencyArray& deps) { | mgb::cg::GraphExecutable::ExecDependencyArray& deps) { | ||||
| deps.emplace_back( | deps.emplace_back( | ||||
| std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | |||||
| std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr))); | |||||
| } | } | ||||
| /* ==================== MultiAxisVecFancyIndexingHelper ==================== */ | /* ==================== MultiAxisVecFancyIndexingHelper ==================== */ | ||||
| @@ -258,14 +258,24 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc( | |||||
| } | } | ||||
| } | } | ||||
| if (all_scalar) { | if (all_scalar) { | ||||
| mgb_log_warn("%s{%s}: no vector indexer; consider using Subtensor " | |||||
| #if MGB_ENABLE_GETENV | |||||
| mgb_log_warn( | |||||
| "%s{%s}: no vector indexer; consider using Subtensor " | |||||
| "family for better performance; you can set " | "family for better performance; you can set " | ||||
| "MGB_THROW_ON_SCALAR_IDX to throw an exception to help " | "MGB_THROW_ON_SCALAR_IDX to throw an exception to help " | ||||
| "tracking the related operator", | "tracking the related operator", | ||||
| cname(), dyn_typeinfo()->name); | cname(), dyn_typeinfo()->name); | ||||
| mgb_throw_if(MGB_GETENV("MGB_THROW_ON_SCALAR_IDX"), | |||||
| MegBrainError, "vector-indexing operator used with all " | |||||
| "scalar indices"); | |||||
| #else | |||||
| mgb_log_warn( | |||||
| "%s{%s}: no vector indexer; consider using Subtensor " | |||||
| "family for better performance", | |||||
| cname(), dyn_typeinfo()->name); | |||||
| #endif | |||||
| #if MGB_ENABLE_GETENV | |||||
| mgb_throw_if(MGB_GETENV("MGB_THROW_ON_SCALAR_IDX"), MegBrainError, | |||||
| "vector-indexing operator used with all " | |||||
| "scalar indices"); | |||||
| #endif | |||||
| } | } | ||||
| // always set m_scalar_idx_warn_printed to be true, so we do not print | // always set m_scalar_idx_warn_printed to be true, so we do not print | ||||
| @@ -377,21 +377,21 @@ MegDNNOprHolder::~MegDNNOprHolder() noexcept = default; | |||||
| void MegDNNOprHolder::mixin_init_output_comp_node(OperatorNodeBase &self) { | void MegDNNOprHolder::mixin_init_output_comp_node(OperatorNodeBase &self) { | ||||
| SingleCNOperatorNode::mixin_init_output_comp_node(self); | SingleCNOperatorNode::mixin_init_output_comp_node(self); | ||||
| create_megdnn_opr(); | create_megdnn_opr(); | ||||
| mgb_assert(m_megdnn_opr); | |||||
| m_megdnn_opr->set_error_tracker(&self); | |||||
| mgb_assert(m_dnn_opr); | |||||
| m_dnn_opr->set_error_tracker(&self); | |||||
| } | } | ||||
| void MegDNNOprHolder::mixin_on_output_comp_node_stream_changed( | void MegDNNOprHolder::mixin_on_output_comp_node_stream_changed( | ||||
| OperatorNodeBase &self) { | OperatorNodeBase &self) { | ||||
| SingleCNOperatorNode::mixin_on_output_comp_node_stream_changed(self); | SingleCNOperatorNode::mixin_on_output_comp_node_stream_changed(self); | ||||
| create_megdnn_opr(); | create_megdnn_opr(); | ||||
| mgb_assert(m_megdnn_opr); | |||||
| m_megdnn_opr->set_error_tracker(&self); | |||||
| mgb_assert(m_dnn_opr); | |||||
| m_dnn_opr->set_error_tracker(&self); | |||||
| } | } | ||||
| void MegDNNOprHolder::set_megdnn_opr( | void MegDNNOprHolder::set_megdnn_opr( | ||||
| std::unique_ptr<megdnn::OperatorBase> self) { | std::unique_ptr<megdnn::OperatorBase> self) { | ||||
| m_megdnn_opr = std::move(self); | |||||
| m_dnn_opr = std::move(self); | |||||
| } | } | ||||
| void MegDNNOprHolder::record_megdnn_opr( | void MegDNNOprHolder::record_megdnn_opr( | ||||
| @@ -402,7 +402,7 @@ void MegDNNOprHolder::record_megdnn_opr( | |||||
| void MegDNNOprHolder::record_megdnn_opr( | void MegDNNOprHolder::record_megdnn_opr( | ||||
| cg::GraphExecutable::ExecDependencyArray& deps) { | cg::GraphExecutable::ExecDependencyArray& deps) { | ||||
| record_megdnn_opr(std::move(m_megdnn_opr), deps); | |||||
| record_megdnn_opr(std::move(m_dnn_opr), deps); | |||||
| } | } | ||||
| /* ================== MegDNNOprHolderBwdStaticInfer ================== */ | /* ================== MegDNNOprHolderBwdStaticInfer ================== */ | ||||
| @@ -59,10 +59,10 @@ cg::OperatorNodeBase::NodeProp* RNGOprBase::do_make_node_prop() const { | |||||
| } | } | ||||
| void RNGOprBase::ensure_megdnn_opr() { | void RNGOprBase::ensure_megdnn_opr() { | ||||
| if (!m_megdnn_opr || m_megdnn_opr.comp_node() != comp_node()) { | |||||
| if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node()) { | |||||
| // activate comp_node for curandCreateGenerator in create_megdnn_opr | // activate comp_node for curandCreateGenerator in create_megdnn_opr | ||||
| comp_node().activate(); | comp_node().activate(); | ||||
| m_megdnn_opr = create_megdnn_opr(); | |||||
| m_dnn_opr = create_megdnn_opr(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -76,7 +76,7 @@ void RNGOprBase::init_output_static_infer_desc() { | |||||
| auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { | auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { | ||||
| ensure_megdnn_opr(); | ensure_megdnn_opr(); | ||||
| dest.ndim = 1; | dest.ndim = 1; | ||||
| dest.shape[0] = m_megdnn_opr->get_workspace_in_bytes( | |||||
| dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( | |||||
| {inp.val.at(0).shape(), output(0)->dtype()}); | {inp.val.at(0).shape(), output(0)->dtype()}); | ||||
| return true; | return true; | ||||
| }; | }; | ||||
| @@ -87,7 +87,7 @@ void RNGOprBase::init_output_static_infer_desc() { | |||||
| } | } | ||||
| void RNGOprBase::scn_do_execute() { | void RNGOprBase::scn_do_execute() { | ||||
| m_megdnn_opr->exec( | |||||
| m_dnn_opr->exec( | |||||
| output(0)->dev_tensor().as_megdnn(), | output(0)->dev_tensor().as_megdnn(), | ||||
| get_megdnn_workspace_from_var(output(1))); | get_megdnn_workspace_from_var(output(1))); | ||||
| } | } | ||||
| @@ -332,7 +332,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::AlgoChooserHelper( | |||||
| const megdnn::param::ExecutionPolicy& execution_policy, | const megdnn::param::ExecutionPolicy& execution_policy, | ||||
| bool allow_weight_preprocess) | bool allow_weight_preprocess) | ||||
| : m_layouts{layouts}, | : m_layouts{layouts}, | ||||
| m_megdnn_opr{megdnn_opr}, | |||||
| m_dnn_opr{megdnn_opr}, | |||||
| m_param{param_str}, | m_param{param_str}, | ||||
| m_base_mgb_opr{mgb_opr}, | m_base_mgb_opr{mgb_opr}, | ||||
| m_cn{cn}, | m_cn{cn}, | ||||
| @@ -356,15 +356,15 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_heuristic( | |||||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
| auto attr = extract_algo_attribute(selected_strategy); | auto attr = extract_algo_attribute(selected_strategy); | ||||
| policy.algo = | policy.algo = | ||||
| APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||||
| APPLY(m_dnn_opr->get_algorithm_info_heuristic( | |||||
| args..., workspace_limit, attr.first, attr.second), | args..., workspace_limit, attr.first, attr.second), | ||||
| m_layouts) | m_layouts) | ||||
| .desc; | .desc; | ||||
| Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | |||||
| Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo); | |||||
| mgb_assert(algo, "Unknown algo description"); | mgb_assert(algo, "Unknown algo description"); | ||||
| std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list( | std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list( | ||||
| to_layout_array<Opr>(m_layouts), m_megdnn_opr); | |||||
| to_layout_array<Opr>(m_layouts), m_dnn_opr); | |||||
| FOREACH_OPR_TYPE_DISPATCH(sub_items, { | FOREACH_OPR_TYPE_DISPATCH(sub_items, { | ||||
| auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn); | auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn); | ||||
| @@ -389,7 +389,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_profile( | |||||
| const ExecutionStrategy& selected_strategy, bool enable_update) const { | const ExecutionStrategy& selected_strategy, bool enable_update) const { | ||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile"))) | ||||
| if (owner_graph()->options().no_profiling_on_shape_change) { | if (owner_graph()->options().no_profiling_on_shape_change) { | ||||
| auto policy = m_megdnn_opr->execution_policy(); | |||||
| auto policy = m_dnn_opr->execution_policy(); | |||||
| if (policy.algo.valid()) { | if (policy.algo.valid()) { | ||||
| return policy; | return policy; | ||||
| } | } | ||||
| @@ -439,9 +439,9 @@ typename AlgoChooser<Opr>::ImplAlgoDesc | |||||
| AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( | AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( | ||||
| const ExecutionStrategy& selected_strategy) const { | const ExecutionStrategy& selected_strategy) const { | ||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache"))) | ||||
| AlgoChooserProfileCache cache(m_cn, profile_name(m_megdnn_opr).c_str()); | |||||
| AlgoChooserProfileCache cache(m_cn, profile_name(m_dnn_opr).c_str()); | |||||
| typename Opr::Param origin_param = m_megdnn_opr->param(); | |||||
| typename Opr::Param origin_param = m_dnn_opr->param(); | |||||
| AlgoChooserProfileCache::Key cache_key{m_layouts.data(), m_layouts.size(), | AlgoChooserProfileCache::Key cache_key{m_layouts.data(), m_layouts.size(), | ||||
| &origin_param, sizeof(origin_param)}; | &origin_param, sizeof(origin_param)}; | ||||
| auto&& rst = cache.get(cache_key); | auto&& rst = cache.get(cache_key); | ||||
| @@ -504,7 +504,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( | |||||
| std::string layouts_str = format_fixlayouts<Opr>( | std::string layouts_str = format_fixlayouts<Opr>( | ||||
| m_layouts, arity_in, arity_out); | m_layouts, arity_in, arity_out); | ||||
| std::string msg = ssprintf( | std::string msg = ssprintf( | ||||
| "(mbg_opr : %s, layouts %s, with attribute(%s) and " | |||||
| "(opr : %s, layouts %s, with attribute(%s) and " | |||||
| "without attribute(%s)", | "without attribute(%s)", | ||||
| m_base_mgb_opr->dyn_typeinfo()->name, | m_base_mgb_opr->dyn_typeinfo()->name, | ||||
| layouts_str.c_str(), | layouts_str.c_str(), | ||||
| @@ -526,7 +526,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( | |||||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
| auto attr = extract_algo_attribute(selected_strategy); | auto attr = extract_algo_attribute(selected_strategy); | ||||
| policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||||
| policy.algo = APPLY(m_dnn_opr->get_algorithm_info_heuristic( | |||||
| args..., workspace_limit, attr.first, | args..., workspace_limit, attr.first, | ||||
| attr.second), | attr.second), | ||||
| m_layouts) | m_layouts) | ||||
| @@ -539,10 +539,10 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( | |||||
| } | } | ||||
| } | } | ||||
| Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | |||||
| Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo); | |||||
| mgb_assert(algo, "Unknown algo description"); | mgb_assert(algo, "Unknown algo description"); | ||||
| std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list( | std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list( | ||||
| to_layout_array<Opr>(m_layouts), m_megdnn_opr); | |||||
| to_layout_array<Opr>(m_layouts), m_dnn_opr); | |||||
| FOREACH_OPR_TYPE_DISPATCH(sub_items, { | FOREACH_OPR_TYPE_DISPATCH(sub_items, { | ||||
| auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn); | auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn); | ||||
| @@ -571,11 +571,11 @@ template <typename Opr> | |||||
| size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes( | size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes( | ||||
| const ImplExecutionPolicy& policy) const { | const ImplExecutionPolicy& policy) const { | ||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_workspace_size_bytes"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_workspace_size_bytes"))) | ||||
| m_megdnn_opr->execution_policy() = policy; | |||||
| m_dnn_opr->execution_policy() = policy; | |||||
| size_t result; | size_t result; | ||||
| if_constexpr<opr_supports_preprocess<Opr>()>( | if_constexpr<opr_supports_preprocess<Opr>()>( | ||||
| [&](auto _) { | [&](auto _) { | ||||
| auto&& opr = _(m_megdnn_opr); | |||||
| auto&& opr = _(m_dnn_opr); | |||||
| auto prep = this->construct_fake_preprocess_filter(); | auto prep = this->construct_fake_preprocess_filter(); | ||||
| PreprocessFilter<Opr>* prep_ptr = | PreprocessFilter<Opr>* prep_ptr = | ||||
| prep.valid() ? &prep.val() : nullptr; | prep.valid() ? &prep.val() : nullptr; | ||||
| @@ -587,7 +587,7 @@ size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes( | |||||
| }, | }, | ||||
| /* else */ | /* else */ | ||||
| [&](auto _) { | [&](auto _) { | ||||
| result = APPLY(_(m_megdnn_opr)->get_workspace_in_bytes(args...), | |||||
| result = APPLY(_(m_dnn_opr)->get_workspace_in_bytes(args...), | |||||
| m_layouts); | m_layouts); | ||||
| }); | }); | ||||
| return result; | return result; | ||||
| @@ -600,7 +600,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { | |||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) | ||||
| auto heu = choose_by_heuristic(m_execution_policy.strategy); | auto heu = choose_by_heuristic(m_execution_policy.strategy); | ||||
| auto&& ret = | auto&& ret = | ||||
| APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts); | |||||
| APPLY(m_dnn_opr->get_all_algorithms_info(args...), m_layouts); | |||||
| bool found = false; | bool found = false; | ||||
| for (size_t i = 0; i < ret.size(); ++i) { | for (size_t i = 0; i < ret.size(); ++i) { | ||||
| if (ret[i].desc == heu.algo) { | if (ret[i].desc == heu.algo) { | ||||
| @@ -610,7 +610,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { | |||||
| } | } | ||||
| } | } | ||||
| Algorithm* palgo = m_megdnn_opr->get_algorithm_from_desc(heu.algo); | |||||
| Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(heu.algo); | |||||
| mgb_assert(palgo, "Unknown algo description"); | mgb_assert(palgo, "Unknown algo description"); | ||||
| mgb_assert(found, | mgb_assert(found, | ||||
| "algo %s got by heuristic not found in " | "algo %s got by heuristic not found in " | ||||
| @@ -644,10 +644,10 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo( | |||||
| mgb_assert(param.shapes.size() == m_layouts.size()); | mgb_assert(param.shapes.size() == m_layouts.size()); | ||||
| for (size_t i = 0; i < param.shapes.size(); ++i) | for (size_t i = 0; i < param.shapes.size(); ++i) | ||||
| param.shapes[i] = m_layouts[i]; | param.shapes[i] = m_layouts[i]; | ||||
| param.opr_param = m_megdnn_opr->param(); | |||||
| param.opr_param = m_dnn_opr->param(); | |||||
| param.allow_weight_preprocess = m_allow_weight_preprocess; | param.allow_weight_preprocess = m_allow_weight_preprocess; | ||||
| Algorithm* palgo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | |||||
| Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(policy.algo); | |||||
| mgb_assert(palgo, "can not find algo when profile single algo"); | mgb_assert(palgo, "can not find algo when profile single algo"); | ||||
| auto rst = TimedProfiler<Opr>::profile(param, timeout); | auto rst = TimedProfiler<Opr>::profile(param, timeout); | ||||
| @@ -691,7 +691,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( | |||||
| policy.algo = algo.desc; | policy.algo = algo.desc; | ||||
| //! check negative attribute : skip negative attribute | //! check negative attribute : skip negative attribute | ||||
| auto palgo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | |||||
| auto palgo = m_dnn_opr->get_algorithm_from_desc(policy.algo); | |||||
| if (palgo->contain_attribute_any(target_attr.second)) { | if (palgo->contain_attribute_any(target_attr.second)) { | ||||
| mgb_log_debug( | mgb_log_debug( | ||||
| "skip algo %s, which matches the profile strategy required " | "skip algo %s, which matches the profile strategy required " | ||||
| @@ -748,12 +748,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( | |||||
| mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); | mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); | ||||
| FixedTensorLayouts origin_layouts = m_layouts; | FixedTensorLayouts origin_layouts = m_layouts; | ||||
| typename Opr::Param origin_param = m_megdnn_opr->param(); | |||||
| typename Opr::Param origin_param = m_dnn_opr->param(); | |||||
| AlgoChooserProfileCache::Key cache_key{origin_layouts.data(), | AlgoChooserProfileCache::Key cache_key{origin_layouts.data(), | ||||
| origin_layouts.size(), &origin_param, | origin_layouts.size(), &origin_param, | ||||
| sizeof(origin_param)}; | sizeof(origin_param)}; | ||||
| AlgoChooserProfileCache cache(m_cn, profile_name(m_megdnn_opr).c_str()); | |||||
| AlgoChooserProfileCache cache(m_cn, profile_name(m_dnn_opr).c_str()); | |||||
| cache.put(cache_key, prof_rst); | cache.put(cache_key, prof_rst); | ||||
| MIDOUT_E | MIDOUT_E | ||||
| } | } | ||||
| @@ -766,7 +766,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::construct_fake_preprocess_filter() const { | |||||
| if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) { | if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) { | ||||
| if (!m_allow_weight_preprocess) | if (!m_allow_weight_preprocess) | ||||
| return; | return; | ||||
| auto opr = _(m_megdnn_opr); | |||||
| auto opr = _(m_dnn_opr); | |||||
| auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...), | auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...), | ||||
| m_layouts); | m_layouts); | ||||
| //! No preprocess layout means no need weight preprocess | //! No preprocess layout means no need weight preprocess | ||||
| @@ -312,10 +312,15 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl( | |||||
| double next_report_time = 0.5; | double next_report_time = 0.5; | ||||
| while (!ev_end->finished()) { | while (!ev_end->finished()) { | ||||
| if (timer.get_secs() >= next_report_time) { | if (timer.get_secs() >= next_report_time) { | ||||
| #if MGB_ENABLE_GETENV | |||||
| mgb_log_warn( | mgb_log_warn( | ||||
| "profiling conv algo %s already took %.3f/%.3f secs" | "profiling conv algo %s already took %.3f/%.3f secs" | ||||
| " (limit can be set by MGB_CONV_PROFILING_TIMEOUT) ", | " (limit can be set by MGB_CONV_PROFILING_TIMEOUT) ", | ||||
| algo->name(), timer.get_secs(), param.actual_timeout); | algo->name(), timer.get_secs(), param.actual_timeout); | ||||
| #else | |||||
| mgb_log_warn("profiling conv algo %s already took %.3f/%.3f secs", | |||||
| algo->name(), timer.get_secs(), param.actual_timeout); | |||||
| #endif | |||||
| next_report_time = timer.get_secs() + 1; | next_report_time = timer.get_secs() + 1; | ||||
| } | } | ||||
| using namespace std::literals; | using namespace std::literals; | ||||
| @@ -111,7 +111,7 @@ void Linspace::scn_do_execute() { | |||||
| stop.dtype(), stop.raw_ptr()).get_cast<double>(); | stop.dtype(), stop.raw_ptr()).get_cast<double>(); | ||||
| auto cn = comp_node(); | auto cn = comp_node(); | ||||
| auto &&opr = m_megdnn_opr; | |||||
| auto &&opr = m_dnn_opr; | |||||
| if (!opr || opr.comp_node() != cn) | if (!opr || opr.comp_node() != cn) | ||||
| opr = intl::create_megdnn_opr<megdnn::Linspace>(cn); | opr = intl::create_megdnn_opr<megdnn::Linspace>(cn); | ||||
| opr->param() = {startv, stopv, m_param.endpoint}; | opr->param() = {startv, stopv, m_param.endpoint}; | ||||
| @@ -122,7 +122,7 @@ void Linspace::scn_do_execute() { | |||||
| void Linspace::record_execute_deps(ExecDependencyArray& deps) { | void Linspace::record_execute_deps(ExecDependencyArray& deps) { | ||||
| deps.emplace_back( | deps.emplace_back( | ||||
| std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | |||||
| std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr))); | |||||
| } | } | ||||
| #if MGB_ENABLE_GRAD | #if MGB_ENABLE_GRAD | ||||
| @@ -184,7 +184,7 @@ cg::OperatorNodeBase::NodeProp* Eye::do_make_node_prop() const { | |||||
| void Eye::scn_do_execute() { | void Eye::scn_do_execute() { | ||||
| auto cn = comp_node(); | auto cn = comp_node(); | ||||
| auto &&opr = m_megdnn_opr; | |||||
| auto &&opr = m_dnn_opr; | |||||
| if (!opr || opr.comp_node() != cn) { | if (!opr || opr.comp_node() != cn) { | ||||
| opr = intl::create_megdnn_opr<megdnn::Eye>(cn); | opr = intl::create_megdnn_opr<megdnn::Eye>(cn); | ||||
| opr->param() = m_param; | opr->param() = m_param; | ||||
| @@ -196,7 +196,7 @@ void Eye::scn_do_execute() { | |||||
| void Eye::record_execute_deps(ExecDependencyArray& deps) { | void Eye::record_execute_deps(ExecDependencyArray& deps) { | ||||
| deps.emplace_back( | deps.emplace_back( | ||||
| std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | |||||
| std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr))); | |||||
| } | } | ||||
| #if MGB_ENABLE_GRAD | #if MGB_ENABLE_GRAD | ||||
| @@ -88,7 +88,7 @@ namespace mixin { | |||||
| template<class Opr> | template<class Opr> | ||||
| class IndexingMultiAxisVecMegDNNOprHolder { | class IndexingMultiAxisVecMegDNNOprHolder { | ||||
| intl::UniqPtrWithCN<Opr> m_megdnn_opr; | |||||
| intl::UniqPtrWithCN<Opr> m_dnn_opr; | |||||
| protected: | protected: | ||||
| Opr& megdnn_opr(cg::SingleCNOperatorNodeBase& self); | Opr& megdnn_opr(cg::SingleCNOperatorNodeBase& self); | ||||
| @@ -136,7 +136,7 @@ namespace mixin { | |||||
| virtual void create_megdnn_opr() = 0; | virtual void create_megdnn_opr() = 0; | ||||
| megdnn::OperatorBase* megdnn_opr() const { | megdnn::OperatorBase* megdnn_opr() const { | ||||
| return m_megdnn_opr.get(); | |||||
| return m_dnn_opr.get(); | |||||
| } | } | ||||
| void set_megdnn_opr(std::unique_ptr<megdnn::OperatorBase> opr); | void set_megdnn_opr(std::unique_ptr<megdnn::OperatorBase> opr); | ||||
| @@ -146,7 +146,7 @@ namespace mixin { | |||||
| cg::GraphExecutable::ExecDependencyArray& deps); | cg::GraphExecutable::ExecDependencyArray& deps); | ||||
| private: | private: | ||||
| std::unique_ptr<megdnn::OperatorBase> m_megdnn_opr; | |||||
| std::unique_ptr<megdnn::OperatorBase> m_dnn_opr; | |||||
| }; | }; | ||||
| class MegDNNOprHolderBwdStaticInfer: public MegDNNOprHolder { | class MegDNNOprHolderBwdStaticInfer: public MegDNNOprHolder { | ||||
| @@ -23,7 +23,7 @@ namespace opr { | |||||
| namespace intl { | namespace intl { | ||||
| MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // { | MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // { | ||||
| UniqPtrWithCN<megdnn::RNGBase> m_megdnn_opr; | |||||
| UniqPtrWithCN<megdnn::RNGBase> m_dnn_opr; | |||||
| void ensure_megdnn_opr(); | void ensure_megdnn_opr(); | ||||
| void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
| @@ -69,7 +69,7 @@ public: | |||||
| using FixedTensorLayouts = std::array<TensorLayout, arity>; | using FixedTensorLayouts = std::array<TensorLayout, arity>; | ||||
| class AlgoChooserHelper { | class AlgoChooserHelper { | ||||
| FixedTensorLayouts m_layouts; | FixedTensorLayouts m_layouts; | ||||
| Opr* m_megdnn_opr; | |||||
| Opr* m_dnn_opr; | |||||
| std::string m_param; | std::string m_param; | ||||
| const cg::OperatorNodeBase* m_base_mgb_opr; | const cg::OperatorNodeBase* m_base_mgb_opr; | ||||
| CompNode m_cn; | CompNode m_cn; | ||||
| @@ -84,7 +84,7 @@ public: | |||||
| const megdnn::param::ExecutionPolicy& execution_policy, | const megdnn::param::ExecutionPolicy& execution_policy, | ||||
| bool allow_weight_preprocess); | bool allow_weight_preprocess); | ||||
| Opr* megdnn_opr() const { return m_megdnn_opr; } | |||||
| Opr* megdnn_opr() const { return m_dnn_opr; } | |||||
| const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; } | const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; } | ||||
| @@ -106,7 +106,7 @@ public: | |||||
| megdnn::Algorithm* get_algorithm_from_desc( | megdnn::Algorithm* get_algorithm_from_desc( | ||||
| const megdnn::Algorithm::Info::Desc& desc) const { | const megdnn::Algorithm::Info::Desc& desc) const { | ||||
| return m_megdnn_opr->get_algorithm_from_desc(desc); | |||||
| return m_dnn_opr->get_algorithm_from_desc(desc); | |||||
| } | } | ||||
| const FixedTensorLayouts& layouts() const { return m_layouts; } | const FixedTensorLayouts& layouts() const { return m_layouts; } | ||||
| @@ -72,7 +72,7 @@ MGB_DEFINE_OPR_CLASS(Linspace, cg::SingleCNOperatorNodeBase) // { | |||||
| private: | private: | ||||
| const Param m_param; | const Param m_param; | ||||
| intl::UniqPtrWithCN<megdnn::Linspace> m_megdnn_opr; | |||||
| intl::UniqPtrWithCN<megdnn::Linspace> m_dnn_opr; | |||||
| void scn_do_execute() override; | void scn_do_execute() override; | ||||
| void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
| @@ -97,7 +97,7 @@ MGB_DEFINE_OPR_CLASS(Eye, cg::SingleCNOperatorNodeBase) // { | |||||
| private: | private: | ||||
| const Param m_param; | const Param m_param; | ||||
| intl::UniqPtrWithCN<megdnn::Eye> m_megdnn_opr; | |||||
| intl::UniqPtrWithCN<megdnn::Eye> m_dnn_opr; | |||||
| void scn_do_execute() override; | void scn_do_execute() override; | ||||
| void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
| @@ -279,6 +279,7 @@ void VarSanityCheck::check_single_input(bool add_debug_log, | |||||
| } | } | ||||
| if (checksum != checksum_expect) { | if (checksum != checksum_expect) { | ||||
| #if MGB_ENABLE_GETENV | |||||
| mgb_throw(Error, | mgb_throw(Error, | ||||
| "var sanity check failed: var: %s" | "var sanity check failed: var: %s" | ||||
| " (checksum: expect=%s got=%s); receiver: %s{%s}(%zu);" | " (checksum: expect=%s got=%s); receiver: %s{%s}(%zu);" | ||||
| @@ -288,6 +289,15 @@ void VarSanityCheck::check_single_input(bool add_debug_log, | |||||
| str(checksum_expect).c_str(), str(checksum).c_str(), | str(checksum_expect).c_str(), str(checksum).c_str(), | ||||
| recv_opr->cname(), recv_opr->dyn_typeinfo()->name, | recv_opr->cname(), recv_opr->dyn_typeinfo()->name, | ||||
| recv_opr->id(), var->id(), !add_debug_log); | recv_opr->id(), var->id(), !add_debug_log); | ||||
| #else | |||||
| mgb_throw(Error, | |||||
| "var sanity check failed: var: %s" | |||||
| " (checksum: expect=%s got=%s); receiver: %s{%s}(%zu);", | |||||
| cg::dump_var_info({var}).c_str(), | |||||
| str(checksum_expect).c_str(), str(checksum).c_str(), | |||||
| recv_opr->cname(), recv_opr->dyn_typeinfo()->name, | |||||
| recv_opr->id()); | |||||
| #endif | |||||
| } | } | ||||
| } | } | ||||
| @@ -292,7 +292,7 @@ ExternCOprRunner::ExternCOprRunner(std::string& name, | |||||
| auto size_diff = sizeof(MGBOprDesc) - m_desc->size; | auto size_diff = sizeof(MGBOprDesc) - m_desc->size; | ||||
| is_loader_support_dynamic_param = (0 == size_diff) ? true : false; | is_loader_support_dynamic_param = (0 == size_diff) ? true : false; | ||||
| mgb_assert(0 == size_diff || sizeof(ExternCOprParam*) == size_diff, | mgb_assert(0 == size_diff || sizeof(ExternCOprParam*) == size_diff, | ||||
| "invalid MGBOprDesc size: expect=%zu got=%u, may caused by " | |||||
| "invalid OprDesc size: expect=%zu got=%u, may caused by " | |||||
| "extern_c_opr.h mismatch, please confirm that the " | "extern_c_opr.h mismatch, please confirm that the " | ||||
| "extern_c_opr.h used when compiling the loader is consistent " | "extern_c_opr.h used when compiling the loader is consistent " | ||||
| "with the runtime caller build used", | "with the runtime caller build used", | ||||
| @@ -531,8 +531,8 @@ cg::OperatorNodeBase* ExternCOprRunner::shallow_copy( | |||||
| } | } | ||||
| MGBTensorShape ExternCOprRunner::tensor_shape_to_c(const TensorShape& shape) { | MGBTensorShape ExternCOprRunner::tensor_shape_to_c(const TensorShape& shape) { | ||||
| mgb_assert(shape.ndim <= MGB_TENSOR_MAX_NDIM, "shape ndim too large: %zu", | |||||
| shape.ndim); | |||||
| mgb_throw_if(shape.ndim > MGB_TENSOR_MAX_NDIM, MegBrainError, | |||||
| "shape ndim too large: %zu", shape.ndim); | |||||
| MGBTensorShape ret; | MGBTensorShape ret; | ||||
| ret.ndim = shape.ndim; | ret.ndim = shape.ndim; | ||||
| for (size_t i = 0; i < shape.ndim; ++i) { | for (size_t i = 0; i < shape.ndim; ++i) { | ||||
| @@ -41,7 +41,8 @@ DType OprLoadContextRawPOD::read_param() { | |||||
| if (m_check_param_tag) { | if (m_check_param_tag) { | ||||
| uint32_t tag; | uint32_t tag; | ||||
| read_raw(&tag, sizeof(tag)); | read_raw(&tag, sizeof(tag)); | ||||
| mgb_assert(tag == megdnn::param::FakeSerializedDType::TAG); | |||||
| mgb_throw_if(tag != megdnn::param::FakeSerializedDType::TAG, | |||||
| MegBrainError, "ERROR tag"); | |||||
| } | } | ||||
| return serialization::deserialize_dtype( | return serialization::deserialize_dtype( | ||||
| [this](void* data, size_t len) { read_raw(data, len); }); | [this](void* data, size_t len) { read_raw(data, len); }); | ||||