GitOrigin-RevId: b7a1dc62d8
tags/v1.10.0
| @@ -51,15 +51,6 @@ PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack; | |||
| size_t PoolingImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst) { | |||
| TensorLayoutArray layouts{src, dst}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| auto param = make_pooling_kern_szie_param(this, src, dst); | |||
| auto algo = get_algorithm(this, src, dst); | |||
| if (!is_fallback_algo(algo)) { | |||
| @@ -13,14 +13,6 @@ namespace megdnn { | |||
| template <class Opr, typename... Args> | |||
| size_t get_dnn_workspace(Opr* opr, Args&&... args) { | |||
| TensorLayoutArray layouts{{args...}}; | |||
| AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(), layouts.data(), | |||
| layouts.size(), &opr->param(), sizeof(opr->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...); | |||
| return get_algorithm(opr, std::forward<Args>(args)...) | |||
| ->get_workspace_in_bytes(size_args); | |||
| @@ -32,6 +24,7 @@ size_t get_dnn_workspace(Opr* opr, Args&&... args) { | |||
| template <class Opr, typename... Args> | |||
| typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||
| typename Opr::AlgorithmDesc ret; | |||
| // first check self configured algorithm | |||
| auto set = opr->execution_policy().algo; | |||
| if (set.valid()) { | |||
| ret = set; | |||
| @@ -40,10 +33,12 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||
| AlgorithmCache::Key key{opr->handle(), opr->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &opr->param(), sizeof(opr->param())}; | |||
| // then get from global algorithm cache | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| ret = rst.policy.algo; | |||
| } else { | |||
| // finally get pre-defined heuristic algorithm | |||
| ret = opr->get_algorithm_info_heuristic( | |||
| std::forward<Args>(args)..., | |||
| std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT, | |||
| @@ -44,14 +44,6 @@ WorkspaceBundle BatchConvBiasForwardImpl::get_workspace_bundle( | |||
| size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias, | |||
| const TensorLayout& z, const TensorLayout& dst) { | |||
| TensorLayoutArray layouts{src, flt, bias, z, dst}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| return get_workspace_bundle(nullptr, src, flt, bias, z, dst).total_size_in_bytes(); | |||
| } | |||
| @@ -187,15 +187,6 @@ void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>( | |||
| size_t ConvBiasForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias, | |||
| const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter*) { | |||
| TensorLayoutArray layouts{src, flt, bias, z, dst}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| size_t float_workspace_size = 0; | |||
| if (z.ndim > 0 && z.dtype.category() != DTypeCategory::FLOAT) { | |||
| @@ -66,15 +66,6 @@ void ConvolutionForwardImpl::exec( | |||
| size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| TensorLayoutArray layouts{filter, diff, grad}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| size_t workspace_size = 0; | |||
| auto flt_dt = filter.dtype.enumv(); | |||
| auto grad_dt = grad.dtype.enumv(); | |||
| @@ -178,15 +169,6 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { | |||
| size_t workspace_size = 0; | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| TensorLayoutArray layouts{src, diff, grad}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| auto src_dt = src.dtype.enumv(); | |||
| auto grad_dt = grad.dtype.enumv(); | |||
| auto diff_dt = diff.dtype.enumv(); | |||
| @@ -397,14 +397,6 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle( | |||
| size_t PoolingForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst) { | |||
| TensorLayoutArray layouts{src, dst}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| return get_workspace_bundle(nullptr, src, dst).total_size_in_bytes(); | |||
| } | |||
| namespace { | |||
| @@ -649,14 +641,6 @@ WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( | |||
| size_t PoolingBackwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| TensorLayoutArray layouts{src, dst, diff, grad}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| return get_workspace_bundle(nullptr, src, dst, diff, grad).total_size_in_bytes(); | |||
| } | |||
| @@ -104,15 +104,6 @@ std::vector<ConvolutionForwardImpl::Algorithm*> ConvolutionForwardImpl:: | |||
| size_t ConvolutionForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, | |||
| const PreprocessedFilter*) { | |||
| TensorLayoutArray layouts{src, filter, dst}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| AlgoBase::SizeArgs args(this, src, filter, dst); | |||
| return get_algorithm(this, src, filter, dst)->get_workspace_in_bytes(args); | |||
| } | |||
| @@ -198,15 +189,6 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl:: | |||
| size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| TensorLayoutArray layouts{filter, diff, grad}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| AlgoBase::SizeArgs args(this, filter, diff, grad); | |||
| return get_algorithm(this, filter, diff, grad)->get_workspace_in_bytes(args); | |||
| } | |||
| @@ -282,15 +264,6 @@ ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl:: | |||
| size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { | |||
| TensorLayoutArray layouts{src, diff, grad}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| AlgoBase::SizeArgs args(this, src, diff, grad); | |||
| return get_algorithm(this, src, diff, grad)->get_workspace_in_bytes(args); | |||
| } | |||
| @@ -35,15 +35,6 @@ WorkspaceBundle megdnn::x86::get_bundle( | |||
| size_t PoolingImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst) { | |||
| TensorLayoutArray layouts{src, dst}; | |||
| AlgorithmCache::Key key{this->handle(), this->get_opr_type(), | |||
| layouts.data(), layouts.size(), | |||
| &this->param(), sizeof(this->param())}; | |||
| auto rst = AlgorithmCache::instance().get(key); | |||
| if (rst.policy.algo.valid()) { | |||
| return rst.workspace; | |||
| } | |||
| auto algo = get_algorithm(this, src, dst); | |||
| if (!is_fallback_algo(algo)) { | |||
| if (is_supported(SIMDType::SSE) && src.dtype == dtype::Float32() && | |||
| @@ -351,7 +351,7 @@ class TimedFuncInvokerImpl final : public TimedFuncInvoker { | |||
| } else { | |||
| CHECK_SYS_ERR(cur_recv); | |||
| } | |||
| mgb_assert(cur_recv > 0); | |||
| mgb_assert(cur_recv >= 0); | |||
| dest += cur_recv; | |||
| size -= cur_recv; | |||
| } | |||
| @@ -950,10 +950,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( | |||
| algo.desc.name.c_str(), layouts_str.c_str()); | |||
| timer.reset(); | |||
| MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); } | |||
| // megbrain catched exception | |||
| MGB_CATCH(std::exception & exc, { | |||
| mgb_log_warn("caught exception during %s: %s", msg.c_str(), exc.what()); | |||
| mgb_log_debug("caught exception during %s: %s", msg.c_str(), exc.what()); | |||
| continue; | |||
| }) | |||
| // megbrain uncatched exception | |||
| MGB_CATCH(..., { | |||
| mgb_log_warn("caught exception during %s", msg.c_str()); | |||
| continue; | |||