GitOrigin-RevId: 6d211ca167
tags/v1.3.0
| @@ -188,6 +188,7 @@ public: | |||
| using AlgorithmInfo = detail::Algorithm::Info; | |||
| using AlgorithmDesc = detail::Algorithm::Info::Desc; | |||
| using Algorithm = detail::Algorithm; | |||
| /*! | |||
| * \brief get a string representation for current algorithm set; | |||
| * | |||
| @@ -209,6 +210,8 @@ public: | |||
| return m_execution_policy; | |||
| } | |||
| virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; | |||
| protected: | |||
| ~MultiAlgoOpr() = default; | |||
| @@ -38,11 +38,12 @@ namespace megdnn { | |||
| return algo_pack().all_algos_map().at(desc); \ | |||
| } | |||
| #define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||
| _opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \ | |||
| megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||
| algo_pack().all_algos_map().end()); \ | |||
| return algo_pack().all_algos_map().at(desc); \ | |||
| #define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ | |||
| _opr::Algorithm* _opr::get_algorithm_from_desc( \ | |||
| const AlgorithmDesc& desc) { \ | |||
| megdnn_assert(algo_pack().all_algos_map().find(desc) != \ | |||
| algo_pack().all_algos_map().end()); \ | |||
| return algo_pack().all_algos_map().at(desc); \ | |||
| } | |||
| /** | |||
| @@ -34,7 +34,8 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||
| std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | |||
| false); | |||
| } | |||
| return opr->get_algo_from_desc(ret.desc); | |||
| return static_cast<typename Opr::AlgoBase*>( | |||
| opr->get_algorithm_from_desc(ret.desc)); | |||
| } | |||
| /*! | |||
| @@ -43,7 +44,6 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||
| */ | |||
| template <class Opr, typename... Args> | |||
| typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { | |||
| typename Opr::AlgorithmInfo ret; | |||
| auto set = opr->execution_policy().algo; | |||
| if (set.valid()) { | |||
| return opr->algo_pack().construct_and_get_algo(set.desc); | |||
| @@ -35,7 +35,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -39,7 +39,7 @@ public: | |||
| bool is_thread_safe() const override { return true; } | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
| @@ -69,7 +69,7 @@ public: | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| @@ -86,6 +86,28 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||
| workspace_limit_in_bytes, reproducible); | |||
| } | |||
| ConvolutionForwardImpl::Algorithm* | |||
| ConvolutionForwardImpl::get_algorithm_from_desc( | |||
| const ConvolutionForward::AlgorithmDesc& desc) { | |||
| auto conv_param = param(); | |||
| auto convbias_opr = this->handle()->create_operator<ConvBiasForward>(); | |||
| convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY, | |||
| conv_param.mode, | |||
| conv_param.sparse, | |||
| conv_param.format, | |||
| conv_param.pad_h, | |||
| conv_param.pad_w, | |||
| conv_param.stride_h, | |||
| conv_param.stride_w, | |||
| conv_param.dilate_h, | |||
| conv_param.dilate_w, | |||
| conv_param.compute_mode}; | |||
| convbias_opr->execution_policy() = {this->execution_policy().algo}; | |||
| return static_cast<ConvBiasForwardImpl*>(convbias_opr.get()) | |||
| ->get_algorithm_from_desc(desc); | |||
| } | |||
| std::vector<ConvolutionForwardImpl::Algorithm*> | |||
| ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| @@ -46,6 +46,8 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||
| megdnn_throw("cuda exec_preprocess has not implemeted yet"); | |||
| } | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| struct ConvBiasExtraData{ | |||
| std::unique_ptr<ConvBiasForward> convbias_opr; | |||
| @@ -98,7 +100,7 @@ public: | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -152,7 +154,7 @@ public: | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -42,7 +42,7 @@ public: | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -92,7 +92,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -143,7 +143,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -46,7 +46,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -97,7 +97,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -151,7 +151,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -33,7 +33,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -65,7 +65,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -98,7 +98,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -46,7 +46,7 @@ public: | |||
| static const AlgoPack& algo_pack() { | |||
| return sm_algo_pack; | |||
| } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
| @@ -29,8 +29,7 @@ public: | |||
| class AlgoDefault; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -454,8 +454,8 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||
| return algos; | |||
| } | |||
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( | |||
| const AlgorithmDesc& desc) const { | |||
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| if (!desc.valid()) { | |||
| return nullptr; | |||
| } else { | |||
| @@ -495,7 +495,7 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( | |||
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | |||
| const NCBKernSizeParam& param, size_t workspace_size) { | |||
| if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
| if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { | |||
| return algo; | |||
| } | |||
| if (!m_prev_selected_algo || | |||
| @@ -381,7 +381,7 @@ private: | |||
| bool is_naive_algo(ConvBiasImpl::Algorithm* algo); | |||
| Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| //! get algorithm set by user or by heuristic | |||
| Algorithm* get_algorithm( | |||
| @@ -361,8 +361,8 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) { | |||
| return ret; | |||
| } | |||
| ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc( | |||
| const AlgorithmDesc& desc) const { | |||
| ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| if (!desc.valid()) { | |||
| return nullptr; | |||
| } else { | |||
| @@ -387,7 +387,7 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc( | |||
| ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( | |||
| const NCBKernSizeParam& param, size_t workspace_size) { | |||
| if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
| if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { | |||
| return algo; | |||
| } | |||
| if (!m_prev_selected_algo || | |||
| @@ -749,8 +749,8 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic( | |||
| } | |||
| ConvolutionBackwardDataImpl::Algorithm* | |||
| ConvolutionBackwardDataImpl::get_algo_from_desc( | |||
| const AlgorithmDesc& desc) const { | |||
| ConvolutionBackwardDataImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| if (!desc.valid()) { | |||
| return nullptr; | |||
| } else { | |||
| @@ -783,7 +783,7 @@ ConvolutionBackwardDataImpl::get_algo_from_desc( | |||
| ConvolutionBackwardDataImpl::Algorithm* | |||
| ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) { | |||
| if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
| if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { | |||
| return algo; | |||
| } | |||
| if (!m_prev_selected_algo || | |||
| @@ -284,7 +284,7 @@ private: | |||
| NCBKernSizeParam m_prev_selected_algo_sizep; | |||
| Algorithm* m_prev_selected_algo = nullptr; | |||
| Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| bool is_naive_algo(ConvolutionImpl::Algorithm* algo); | |||
| Algorithm* get_algorithm( | |||
| const NCBKernSizeParam& param, | |||
| @@ -493,7 +493,7 @@ private: | |||
| class AlgoDirect; | |||
| class AlgoMatrixMul; | |||
| class AlgoPack; | |||
| Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| public: | |||
| //! maintain all the algos of in the opr of fallback | |||
| @@ -96,7 +96,7 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms( | |||
| return gemv_algos; | |||
| } | |||
| MatrixMulImpl::AlgoBase* MatrixMulImpl::get_algo_from_desc( | |||
| MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| if (!desc.valid()) { | |||
| return nullptr; | |||
| @@ -133,7 +133,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||
| auto kern_size_param = make_kern_size_param(A, B, C); | |||
| if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { | |||
| if (auto algo = static_cast<AlgoBase*>( | |||
| get_algorithm_from_desc(execution_policy().algo.desc))) { | |||
| megdnn_assert(algo->get_workspace(kern_size_param) < | |||
| workspace_limit_in_bytes); | |||
| auto cur = megdnn::get_reproducible_algo<MatrixMulImpl>(algo, | |||
| @@ -238,7 +238,8 @@ private: | |||
| class AlgoPack; | |||
| //! maintain all the algos of in the opr of fallback | |||
| static const AlgoPack& algo_pack(); | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| public: | |||
| /** | |||
| @@ -138,4 +138,12 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| BatchConvBiasForward::Algorithm* | |||
| BatchConvBiasForwardImpl::get_algorithm_from_desc(const AlgorithmDesc& desc) { | |||
| Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
| ->default_batch_conv_bias_fwd_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -39,6 +39,8 @@ public: | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| @@ -81,6 +81,15 @@ BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
| ->default_batched_matmul_fwd_algo(); | |||
| } | |||
| BatchedMatrixMulForward::Algorithm* | |||
| BatchedMatrixMulForwardImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
| ->default_batched_matmul_fwd_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| @@ -34,6 +34,8 @@ public: | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /* reproducible */) override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
| private: | |||
| @@ -256,6 +256,15 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| ConvBiasForward::Algorithm* | |||
| ConvBiasForwardImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = | |||
| static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| const char* ConvBiasForwardImpl::get_algorithm_set_name() const { | |||
| return "DEFAULT"; | |||
| } | |||
| @@ -64,6 +64,8 @@ public: | |||
| _megdnn_workspace) override {} | |||
| const char* get_algorithm_set_name() const override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| }; | |||
| void handle_z_inp_and_activation_naive( | |||
| @@ -285,6 +285,14 @@ ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = | |||
| static_cast<HandleImpl*>(handle())->default_conv_fwd_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| std::vector<ConvolutionBackwardData::Algorithm *> | |||
| ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | |||
| const TensorLayout &, const TensorLayout &) | |||
| @@ -309,6 +317,15 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| ConvolutionBackwardData::Algorithm* | |||
| ConvolutionBackwardDataImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = | |||
| static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| std::vector<ConvolutionBackwardFilter::Algorithm *> | |||
| ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | |||
| const TensorLayout &, const TensorLayout &) | |||
| @@ -333,6 +350,15 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| ConvolutionBackwardFilter::Algorithm* | |||
| ConvolutionBackwardFilterImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = | |||
| static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| const char* ConvolutionForwardImpl::get_algorithm_set_name() const { | |||
| return "DEFAULT"; | |||
| } | |||
| @@ -52,6 +52,8 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||
| return {}; | |||
| } | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| }; | |||
| @@ -74,6 +76,8 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||
| const TensorLayout&) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| }; | |||
| class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
| @@ -95,6 +99,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
| const TensorLayout&) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| }; | |||
| } // namespace naive | |||
| @@ -6,15 +6,15 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "./helper.h" | |||
| #include "./opr_impl.h" | |||
| #include "src/naive/handle.h" | |||
| #include "src/naive/handle.h" | |||
| #include "src/common/utils.h" | |||
| #include "megdnn/dtype.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include <cstring> | |||
| @@ -25,93 +25,95 @@ using namespace megdnn; | |||
| using namespace naive; | |||
| void Convolution3DForwardImpl::exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in filter, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) | |||
| { | |||
| _megdnn_tensor_in filter, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) { | |||
| MIDOUT_BEGIN(megdnn_naive_conv3d_fwd) { | |||
| auto filter_meta = check_exec( | |||
| src.layout, filter.layout, dst.layout, workspace.size); | |||
| switch (param().data_type) { | |||
| case Param::DataType::FLOAT: | |||
| #define cb(dt) do { \ | |||
| if (src.layout.dtype == dt()) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \ | |||
| convolution3d::forward< \ | |||
| ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
| src, filter, dst, filter_meta); \ | |||
| ); \ | |||
| return; \ | |||
| } \ | |||
| } while(0); | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||
| auto filter_meta = check_exec(src.layout, filter.layout, dst.layout, | |||
| workspace.size); | |||
| switch (param().data_type) { | |||
| case Param::DataType::FLOAT: | |||
| #define cb(dt) \ | |||
| do { \ | |||
| if (src.layout.dtype == dt()) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<HandleImpl*>(handle()), \ | |||
| convolution3d::forward< \ | |||
| ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
| src, filter, dst, filter_meta);); \ | |||
| return; \ | |||
| } \ | |||
| } while (0); | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||
| #undef cb | |||
| break; | |||
| case Param::DataType::FLOAT_IO16xC32: | |||
| MEGDNN_INC_FLOAT16( | |||
| MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), | |||
| convolution3d::forward< | |||
| dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA dt_float32>( | |||
| src, filter, dst, filter_meta);)); | |||
| return; | |||
| break; | |||
| case Param::DataType::FLOAT_IO16xC32: | |||
| MEGDNN_INC_FLOAT16(MEGDNN_DISPATCH_CPU_KERN( | |||
| static_cast<HandleImpl*>(handle()), | |||
| convolution3d::forward< | |||
| dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA | |||
| dt_float32>(src, filter, dst, | |||
| filter_meta);)); | |||
| return; | |||
| } | |||
| megdnn_assert_internal(0); | |||
| } | |||
| megdnn_assert_internal(0); | |||
| } MIDOUT_END(); | |||
| MIDOUT_END(); | |||
| } | |||
| void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) | |||
| { | |||
| auto filter_meta = check_exec( | |||
| filter.layout, diff.layout, grad.layout, workspace.size); | |||
| #define cb(dt) do { \ | |||
| if (filter.layout.dtype == dt()) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \ | |||
| convolution3d::backward_data< \ | |||
| ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
| filter, diff, grad, filter_meta);); \ | |||
| return; \ | |||
| } \ | |||
| } while(0); | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| auto filter_meta = | |||
| check_exec(filter.layout, diff.layout, grad.layout, workspace.size); | |||
| #define cb(dt) \ | |||
| do { \ | |||
| if (filter.layout.dtype == dt()) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<HandleImpl*>(handle()), \ | |||
| convolution3d::backward_data< \ | |||
| ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
| filter, diff, grad, filter_meta);); \ | |||
| return; \ | |||
| } \ | |||
| } while (0); | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||
| #undef cb | |||
| megdnn_assert_internal(0); | |||
| } | |||
| void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) | |||
| { | |||
| auto filter_meta = check_exec( | |||
| src.layout, diff.layout, grad.layout, workspace.size); | |||
| #define cb(dt) do { \ | |||
| if (src.layout.dtype == dt()) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN(static_cast<HandleImpl *>(handle()), \ | |||
| convolution3d::backward_filter< \ | |||
| ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
| src, diff, grad, filter_meta);); \ | |||
| return; \ | |||
| } \ | |||
| } while(0); | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| auto filter_meta = | |||
| check_exec(src.layout, diff.layout, grad.layout, workspace.size); | |||
| #define cb(dt) \ | |||
| do { \ | |||
| if (src.layout.dtype == dt()) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<HandleImpl*>(handle()), \ | |||
| convolution3d::backward_filter< \ | |||
| ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ | |||
| src, diff, grad, filter_meta);); \ | |||
| return; \ | |||
| } \ | |||
| } while (0); | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||
| #undef cb | |||
| megdnn_assert_internal(0); | |||
| } | |||
| std::vector<Convolution3DForward::Algorithm *> | |||
| Convolution3DForwardImpl:: get_all_algorithms(const TensorLayout &, | |||
| const TensorLayout &, const TensorLayout &) | |||
| { | |||
| return {static_cast<HandleImpl *>(handle())->default_conv3d_fwd_algo()}; | |||
| std::vector<Convolution3DForward::Algorithm*> | |||
| Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | |||
| } | |||
| Convolution3DForward::Algorithm* | |||
| @@ -130,11 +132,20 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| std::vector<Convolution3DBackwardData::Algorithm *> | |||
| Convolution3DBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | |||
| const TensorLayout &, const TensorLayout &) | |||
| { | |||
| return {static_cast<HandleImpl *>(handle())->default_conv3d_bwd_data_algo()}; | |||
| Convolution3DForward::Algorithm* | |||
| Convolution3DForwardImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = | |||
| static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| std::vector<Convolution3DBackwardData::Algorithm*> | |||
| Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | |||
| } | |||
| Convolution3DBackwardData::Algorithm* | |||
| @@ -154,11 +165,21 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| std::vector<Convolution3DBackwardFilter::Algorithm *> | |||
| Convolution3DBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | |||
| const TensorLayout &, const TensorLayout &) | |||
| { | |||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_filter_algo()}; | |||
| Convolution3DBackwardData::Algorithm* | |||
| Convolution3DBackwardDataImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = | |||
| static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| std::vector<Convolution3DBackwardFilter::Algorithm*> | |||
| Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle()) | |||
| ->default_conv3d_bwd_filter_algo()}; | |||
| } | |||
| Convolution3DBackwardFilter::Algorithm* | |||
| @@ -179,6 +200,15 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| Convolution3DBackwardFilter::Algorithm* | |||
| Convolution3DBackwardFilterImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
| ->default_conv3d_bwd_filter_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| const char* Convolution3DForwardImpl::get_algorithm_set_name() const { | |||
| return "DEFAULT"; | |||
| } | |||
| @@ -6,81 +6,79 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| namespace megdnn { | |||
| namespace naive { | |||
| class Convolution3DForwardImpl: public Convolution3DForward { | |||
| public: | |||
| using Convolution3DForward::Convolution3DForward; | |||
| void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in filter, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
| const TensorLayout &filter, | |||
| const TensorLayout &dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| const char* get_algorithm_set_name() const override; | |||
| class Convolution3DForwardImpl : public Convolution3DForward { | |||
| public: | |||
| using Convolution3DForward::Convolution3DForward; | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| }; | |||
| class Convolution3DBackwardDataImpl: public Convolution3DBackwardData { | |||
| public: | |||
| using Convolution3DBackwardData::Convolution3DBackwardData; | |||
| void exec(_megdnn_tensor_in filter, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| class Convolution3DBackwardDataImpl : public Convolution3DBackwardData { | |||
| public: | |||
| using Convolution3DBackwardData::Convolution3DBackwardData; | |||
| void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| const char* get_algorithm_set_name() const override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| }; | |||
| class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter { | |||
| public: | |||
| using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||
| void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| const char* get_algorithm_set_name() const override; | |||
| class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter { | |||
| public: | |||
| using Convolution3DBackwardFilter::Convolution3DBackwardFilter; | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override; | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -48,6 +48,10 @@ public: | |||
| return "DEFORMABLE_CONV2_NAIVE"; | |||
| }; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { | |||
| return {}; | |||
| } | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
| _megdnn_tensor_in offset, _megdnn_tensor_in mask, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
| @@ -84,6 +88,10 @@ public: | |||
| return "DEFORMABLE_CONV2_BWD_FILTER_NAIVE"; | |||
| }; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { | |||
| return {}; | |||
| } | |||
| void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, | |||
| _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, | |||
| _megdnn_tensor_out filter_grad, | |||
| @@ -130,6 +138,10 @@ public: | |||
| return "DEFORMABLE_CONV2_BWD_DATA_NAIVE"; | |||
| }; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { | |||
| return {}; | |||
| } | |||
| void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, | |||
| _megdnn_tensor_in offset, _megdnn_tensor_in mask, | |||
| _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, | |||
| @@ -175,6 +175,15 @@ LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| LocalShareForward::Algorithm* | |||
| LocalShareForwardImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = | |||
| static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| std::vector<LocalShareBackwardData::Algorithm*> | |||
| LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||
| const TensorLayout&, | |||
| @@ -200,6 +209,15 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| LocalShareBackwardData::Algorithm* | |||
| LocalShareBackwardDataImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
| ->default_local_share_bwd_data_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| std::vector<LocalShareBackwardFilter::Algorithm*> | |||
| LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||
| const TensorLayout&, | |||
| @@ -225,4 +243,13 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| LocalShareBackwardFilter::Algorithm* | |||
| LocalShareBackwardFilterImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = static_cast<HandleImpl*>(handle()) | |||
| ->default_local_share_bwd_filter_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| @@ -35,6 +36,7 @@ public: | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /*reproducible*/) override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
| }; | |||
| @@ -59,6 +61,7 @@ public: | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /*reproducible*/) override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
| }; | |||
| @@ -83,6 +86,7 @@ public: | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /*reproducible*/) override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
| }; | |||
| @@ -95,6 +95,14 @@ MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||
| return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | |||
| } | |||
| MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = | |||
| static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | |||
| megdnn_assert(desc == ret->info().desc); | |||
| return ret; | |||
| } | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| @@ -35,6 +35,8 @@ public: | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /* reproducible */) override; | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
| private: | |||
| @@ -29,8 +29,8 @@ public: | |||
| class AlgoBlas; | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| @@ -66,7 +66,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -112,7 +112,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -158,7 +158,7 @@ public: | |||
| class AlgoPack; | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| private: | |||
| @@ -29,7 +29,7 @@ public: | |||
| class AlgoPack; | |||
| static const AlgoPack& algo_pack() { return sm_algo_pack; } | |||
| static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); | |||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | |||
| private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| @@ -41,6 +41,7 @@ private: | |||
| const TensorLayout& /*C*/, | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /*reproducible*/) override; | |||
| const char* get_algorithm_set_name() const override { | |||
| return "ROCM MATMUL"; | |||
| } | |||
| @@ -2204,6 +2204,10 @@ public: | |||
| const TensorLayout& p2, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible)); | |||
| MOCK_METHOD1(get_algorithm_from_desc, | |||
| Algorithm*(const AlgorithmDesc&)); | |||
| protected: | |||
| const char* get_algorithm_set_name() const override { | |||
| return m_algorithm_set_name; | |||