GitOrigin-RevId: e3734e4531
tags/v1.6.0
| @@ -315,7 +315,7 @@ public: | |||||
| /*! | /*! | ||||
| * \brief get a string representation for current algorithm set; | * \brief get a string representation for current algorithm set; | ||||
| * | * | ||||
| * get_all_algorithms() may return different algorithms only if | |||||
| * get_all_algorithms_safe() may return different algorithms only if | |||||
| * algorithm set name differs. This is used for checking cache | * algorithm set name differs. This is used for checking cache | ||||
| * validity. | * validity. | ||||
| */ | */ | ||||
| @@ -354,6 +354,15 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
| const TensorLayout& p1) { | |||||
| std::vector<AlgorithmInfo> ret; | |||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1)) { | |||||
| ret.emplace_back(algo->info()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm information, which indicates the | * \brief Returns the best algorithm information, which indicates the | ||||
| * algorithm chosen by heuristic. | * algorithm chosen by heuristic. | ||||
| @@ -378,6 +387,8 @@ protected: | |||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& p0, const TensorLayout& p1) = 0; | const TensorLayout& p0, const TensorLayout& p1) = 0; | ||||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1) = 0; | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
| @@ -412,6 +423,16 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2) { | |||||
| std::vector<AlgorithmInfo> ret; | |||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { | |||||
| ret.emplace_back(algo->info()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm information, which indicates the | * \brief Returns the best algorithm information, which indicates the | ||||
| * algorithm chosen by heuristic. | * algorithm chosen by heuristic. | ||||
| @@ -438,6 +459,9 @@ protected: | |||||
| virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
| const TensorLayout& p2) = 0; | const TensorLayout& p2) = 0; | ||||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2) = 0; | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
| @@ -463,7 +487,7 @@ public: | |||||
| using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
| //! get all possible algorithm descriptions for the specified layouts | //! get all possible algorithm descriptions for the specified layouts | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | const TensorLayout& p1, | ||||
| const TensorLayout& p2, | const TensorLayout& p2, | ||||
| const TensorLayout& p3) { | const TensorLayout& p3) { | ||||
| @@ -474,6 +498,17 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| const TensorLayout& p3) { | |||||
| std::vector<AlgorithmInfo> ret; | |||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { | |||||
| ret.emplace_back(algo->info()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm information, which indicates the | * \brief Returns the best algorithm information, which indicates the | ||||
| * algorithm chosen by heuristic. | * algorithm chosen by heuristic. | ||||
| @@ -500,6 +535,9 @@ protected: | |||||
| virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
| const TensorLayout& p2, const TensorLayout& p3) = 0; | const TensorLayout& p2, const TensorLayout& p3) = 0; | ||||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3) = 0; | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
| @@ -537,6 +575,18 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| const TensorLayout& p3, | |||||
| const TensorLayout& p4) { | |||||
| std::vector<AlgorithmInfo> ret; | |||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { | |||||
| ret.emplace_back(algo->info()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm information, which indicates the | * \brief Returns the best algorithm information, which indicates the | ||||
| * algorithm chosen by heuristic. | * algorithm chosen by heuristic. | ||||
| @@ -562,7 +612,11 @@ protected: | |||||
| ~MultiAlgoOpr() = default; | ~MultiAlgoOpr() = default; | ||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4) = 0; | |||||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
| const TensorLayout& p2, const TensorLayout& p3, | const TensorLayout& p2, const TensorLayout& p3, | ||||
| const TensorLayout& p4) = 0; | const TensorLayout& p4) = 0; | ||||
| @@ -604,6 +658,18 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7) { | |||||
| std::vector<AlgorithmInfo> ret; | |||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||||
| ret.emplace_back(algo->info()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm information, which indicates the | * \brief Returns the best algorithm information, which indicates the | ||||
| * algorithm chosen by heuristic. | * algorithm chosen by heuristic. | ||||
| @@ -629,7 +695,12 @@ protected: | |||||
| ~MultiAlgoOpr() = default; | ~MultiAlgoOpr() = default; | ||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7) = 0; | |||||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
| const TensorLayout& p2, const TensorLayout& p3, | const TensorLayout& p2, const TensorLayout& p3, | ||||
| const TensorLayout& p4, const TensorLayout& p5, | const TensorLayout& p4, const TensorLayout& p5, | ||||
| @@ -172,9 +172,14 @@ std::vector<Algorithm*> PoolingImpl::get_all_algorithms( | |||||
| ret.push_back(i); | ret.push_back(i); | ||||
| } | } | ||||
| } | } | ||||
| megdnn_assert(!ret.empty(), "no usable pooling fwd algorithm"); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst) { | |||||
| auto ret_safe = get_all_algorithms(src,dst); | |||||
| megdnn_assert(!ret_safe.empty(), "no usable pooling fwd algorithm"); | |||||
| return ret_safe; | |||||
| } | |||||
| Algorithm* PoolingImpl::get_algorithm_heuristic( | Algorithm* PoolingImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| @@ -131,6 +131,8 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| @@ -100,10 +100,16 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms( | |||||
| ret.push_back(i); | ret.push_back(i); | ||||
| } | } | ||||
| } | } | ||||
| megdnn_assert(!ret.empty(), "no algorithm for %s", | |||||
| args.to_string().c_str()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| template <class Opr> | |||||
| std::vector<typename Opr::Algorithm*> get_all_algorithms_safe( | |||||
| const typename Opr::AlgoBase::SizeArgs& args) { | |||||
| auto ret_safe = get_all_algorithms<Opr>(args); | |||||
| megdnn_assert(!ret_safe.empty(), "no algorithm for %s", | |||||
| args.to_string().c_str()); | |||||
| return ret_safe; | |||||
| } | |||||
| /*! | /*! | ||||
| * \brief a helper function to get an algorithm match attribute. If require a | * \brief a helper function to get an algorithm match attribute. If require a | ||||
| @@ -51,6 +51,15 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
| AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | ||||
| return megdnn::get_all_algorithms<BatchConvBiasForwardImpl>(args); | return megdnn::get_all_algorithms<BatchConvBiasForwardImpl>(args); | ||||
| } | } | ||||
| std::vector<BatchConvBiasForwardImpl::Algorithm*> | |||||
| BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& bias, | |||||
| const TensorLayout& z, | |||||
| const TensorLayout& dst) { | |||||
| AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | |||||
| return megdnn::get_all_algorithms_safe<BatchConvBiasForwardImpl>(args); | |||||
| } | |||||
| size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( | size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| @@ -42,6 +42,10 @@ protected: | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| @@ -51,6 +51,12 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms( | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms_safe( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||||
| auto ret_safe = get_all_algorithms(A,B,C); | |||||
| megdnn_assert(!ret_safe.empty(), "no usable batchedmatrixmulForward fwd algorithm"); | |||||
| return ret_safe; | |||||
| } | |||||
| Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| @@ -45,6 +45,9 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | ||||
| const TensorLayout& B, | const TensorLayout& B, | ||||
| const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
| @@ -49,6 +49,16 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
| {this, src, filter, bias, z, dst}); | {this, src, filter, bias, z, dst}); | ||||
| } | } | ||||
| std::vector<ConvBiasForward::Algorithm*> | |||||
| ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& bias, | |||||
| const TensorLayout& z, | |||||
| const TensorLayout& dst) { | |||||
| return megdnn::get_all_algorithms_safe<ConvBiasForwardImpl>( | |||||
| {this, src, filter, bias, z, dst}); | |||||
| } | |||||
| ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| @@ -84,6 +84,10 @@ public: | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| @@ -53,6 +53,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
| return megdnn::get_all_algorithms<ConvolutionForwardImpl>(args); | return megdnn::get_all_algorithms<ConvolutionForwardImpl>(args); | ||||
| } | } | ||||
| std::vector<ConvolutionForwardImpl::Algorithm*> | |||||
| ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst) { | |||||
| AlgoBase::SizeArgs args{this, src, filter, dst}; | |||||
| return megdnn::get_all_algorithms_safe<ConvolutionForwardImpl>(args); | |||||
| } | |||||
| size_t ConvolutionForwardImpl::get_workspace_in_bytes( | size_t ConvolutionForwardImpl::get_workspace_in_bytes( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| @@ -97,6 +105,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
| {this, filter, diff, grad}); | {this, filter, diff, grad}); | ||||
| } | } | ||||
| std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||||
| ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| return megdnn::get_all_algorithms_safe<ConvolutionBackwardDataImpl>( | |||||
| {this, filter, diff, grad}); | |||||
| } | |||||
| ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| @@ -222,6 +238,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||||
| {this, src, diff, grad}); | {this, src, diff, grad}); | ||||
| } | } | ||||
| std::vector<ConvolutionBackwardFilterImpl::Algorithm*> | |||||
| ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| return megdnn::get_all_algorithms_safe<ConvolutionBackwardFilterImpl>( | |||||
| {this, src, diff, grad}); | |||||
| } | |||||
| ConvolutionBackwardFilterImpl::Algorithm* | ConvolutionBackwardFilterImpl::Algorithm* | ||||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| @@ -59,6 +59,10 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| @@ -111,6 +115,10 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -159,6 +167,10 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -108,6 +108,14 @@ Convolution3DForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
| {this, src, filter, dst}); | {this, src, filter, dst}); | ||||
| } | } | ||||
| std::vector<Convolution3DForwardImpl::Algorithm*> | |||||
| Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst) { | |||||
| return megdnn::get_all_algorithms_safe<Convolution3DForwardImpl>( | |||||
| {this, src, filter, dst}); | |||||
| } | |||||
| size_t Convolution3DForwardImpl::get_workspace_in_bytes( | size_t Convolution3DForwardImpl::get_workspace_in_bytes( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
| @@ -146,6 +154,14 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
| {this, filter, diff, grad}); | {this, filter, diff, grad}); | ||||
| } | } | ||||
| std::vector<Convolution3DBackwardDataImpl::Algorithm*> | |||||
| Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| return megdnn::get_all_algorithms_safe<Convolution3DBackwardDataImpl>( | |||||
| {this, filter, diff, grad}); | |||||
| } | |||||
| Convolution3DBackwardDataImpl::Algorithm* | Convolution3DBackwardDataImpl::Algorithm* | ||||
| Convolution3DBackwardDataImpl::get_algorithm_heuristic( | Convolution3DBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| @@ -226,6 +242,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||||
| {this, src, diff, grad}); | {this, src, diff, grad}); | ||||
| } | } | ||||
| std::vector<Convolution3DBackwardFilterImpl::Algorithm*> | |||||
| Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| return megdnn::get_all_algorithms_safe<Convolution3DBackwardFilterImpl>( | |||||
| {this, src, diff, grad}); | |||||
| } | |||||
| Convolution3DBackwardFilterImpl::Algorithm* | Convolution3DBackwardFilterImpl::Algorithm* | ||||
| Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| @@ -39,6 +39,9 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| @@ -72,6 +75,9 @@ public: | |||||
| protected: | protected: | ||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| @@ -109,6 +115,9 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -51,6 +51,15 @@ std::vector<AlgoFwd*> Fwd::get_all_algorithms(const TensorLayout& /* im */, | |||||
| return algos; | return algos; | ||||
| } | } | ||||
| std::vector<AlgoFwd*> Fwd::get_all_algorithms_safe(const TensorLayout& im, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& offset, | |||||
| const TensorLayout& mask, | |||||
| const TensorLayout& dst) { | |||||
| auto ret_safe = Fwd::get_all_algorithms(im,filter,offset,mask,dst); | |||||
| megdnn_assert(!ret_safe.empty(), "no usable deformable_conv fwd algorithm"); | |||||
| return ret_safe; | |||||
| } | |||||
| AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | ||||
| const TensorLayout& filter, | const TensorLayout& filter, | ||||
| @@ -115,6 +124,14 @@ std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms(const TensorLayout& /* im */ | |||||
| return algos; | return algos; | ||||
| } | } | ||||
| std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms_safe(const TensorLayout& im, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, const TensorLayout& filter_grad) { | |||||
| auto ret_safe = BwdFlt::get_all_algorithms(im,offset,mask,out_grad,filter_grad); | |||||
| megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd filter algorithm"); | |||||
| return ret_safe; | |||||
| } | |||||
| AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | ||||
| const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
| const TensorLayout& mask, const TensorLayout& out_grad, | const TensorLayout& mask, const TensorLayout& out_grad, | ||||
| @@ -181,6 +198,14 @@ std::vector<AlgoBwdData*> BwdData::get_all_algorithms( | |||||
| algos.push_back(static_cast<AlgoBwdData*>(i)); | algos.push_back(static_cast<AlgoBwdData*>(i)); | ||||
| return algos; | return algos; | ||||
| } | } | ||||
| std::vector<AlgoBwdData*> BwdData::get_all_algorithms_safe( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, | |||||
| const TensorLayout& im_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad ) { | |||||
| auto ret_safe = BwdData::get_all_algorithms(im,filter,offset,mask,out_grad,im_grad,offset_grad,mask_grad); | |||||
| megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd data algorithm"); | |||||
| return ret_safe; | |||||
| } | |||||
| AlgoBwdData* BwdData::get_algorithm_heuristic( | AlgoBwdData* BwdData::get_algorithm_heuristic( | ||||
| const TensorLayout& im, const TensorLayout& filter, | const TensorLayout& im, const TensorLayout& filter, | ||||
| @@ -54,6 +54,10 @@ protected: | |||||
| const TensorLayout& im, const TensorLayout& filter, | const TensorLayout& im, const TensorLayout& filter, | ||||
| const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& im, const TensorLayout& filter, | const TensorLayout& im, const TensorLayout& filter, | ||||
| @@ -105,6 +109,10 @@ protected: | |||||
| const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
| const TensorLayout& mask, const TensorLayout& out_grad, | const TensorLayout& mask, const TensorLayout& out_grad, | ||||
| const TensorLayout& filter_grad) override; | const TensorLayout& filter_grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& im, const TensorLayout& offset, | |||||
| const TensorLayout& mask, const TensorLayout& out_grad, | |||||
| const TensorLayout& filter_grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
| @@ -161,6 +169,13 @@ protected: | |||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
| const TensorLayout& offset_grad, | const TensorLayout& offset_grad, | ||||
| const TensorLayout& mask_grad) override; | const TensorLayout& mask_grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||||
| const TensorLayout& offset_grad, | |||||
| const TensorLayout& mask_grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& im, const TensorLayout& filter, | const TensorLayout& im, const TensorLayout& filter, | ||||
| @@ -47,7 +47,6 @@ LocalShareForwardImpl::get_algorithm_heuristic( | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | Algorithm::attribute_str(positive_attr).c_str(), | ||||
| args.to_string().c_str(), workspace_limit_in_bytes)); | args.to_string().c_str(), workspace_limit_in_bytes)); | ||||
| } | } | ||||
| std::vector<LocalShareForwardImpl::Algorithm*> | std::vector<LocalShareForwardImpl::Algorithm*> | ||||
| LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, | LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, | ||||
| const TensorLayout& filter, | const TensorLayout& filter, | ||||
| @@ -56,6 +55,14 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
| return megdnn::get_all_algorithms<LocalShareForwardImpl>(args); | return megdnn::get_all_algorithms<LocalShareForwardImpl>(args); | ||||
| } | } | ||||
| std::vector<LocalShareForwardImpl::Algorithm*> | |||||
| LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst) { | |||||
| AlgoBase::SizeArgs args{this, src, filter, dst}; | |||||
| return megdnn::get_all_algorithms_safe<LocalShareForwardImpl>(args); | |||||
| } | |||||
| size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | ||||
| const TensorLayout& filter, | const TensorLayout& filter, | ||||
| const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
| @@ -109,6 +116,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
| return megdnn::get_all_algorithms<LocalShareBackwardDataImpl>(args); | return megdnn::get_all_algorithms<LocalShareBackwardDataImpl>(args); | ||||
| } | } | ||||
| std::vector<LocalShareBackwardDataImpl::Algorithm*> | |||||
| LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| AlgoBase::SizeArgs args{this, filter, diff, grad}; | |||||
| return megdnn::get_all_algorithms_safe<LocalShareBackwardDataImpl>(args); | |||||
| } | |||||
| size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter, | size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const TensorLayout& grad) { | const TensorLayout& grad) { | ||||
| @@ -162,6 +177,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||||
| return megdnn::get_all_algorithms<LocalShareBackwardFilterImpl>(args); | return megdnn::get_all_algorithms<LocalShareBackwardFilterImpl>(args); | ||||
| } | } | ||||
| std::vector<LocalShareBackwardFilterImpl::Algorithm*> | |||||
| LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| AlgoBase::SizeArgs args{this, src, diff, grad}; | |||||
| return megdnn::get_all_algorithms_safe<LocalShareBackwardFilterImpl>(args); | |||||
| } | |||||
| size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src, | size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const TensorLayout& grad) { | const TensorLayout& grad) { | ||||
| @@ -37,6 +37,9 @@ public: | |||||
| protected: | protected: | ||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| @@ -72,6 +75,9 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -105,6 +111,9 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -28,6 +28,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | ||||
| } | } | ||||
| std::vector<MatrixMulForwardImpl::Algorithm*> | |||||
| MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_all_algorithms_safe<MatrixMulForwardImpl>(args); | |||||
| } | |||||
| MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
| @@ -60,6 +60,10 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | ||||
| const TensorLayout& B, | const TensorLayout& B, | ||||
| const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
| @@ -33,6 +33,11 @@ PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
| const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
| return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | ||||
| } | } | ||||
| std::vector<PoolingForwardImpl::Algorithm*> | |||||
| PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& dst) { | |||||
| return megdnn::get_all_algorithms_safe<PoolingForwardImpl>({this, src, dst}); | |||||
| } | |||||
| PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| @@ -77,6 +82,15 @@ PoolingBackwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
| {this, src, dst, diff, grad}); | {this, src, dst, diff, grad}); | ||||
| } | } | ||||
| std::vector<PoolingBackwardImpl::Algorithm*> | |||||
| PoolingBackwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& dst, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| return megdnn::get_all_algorithms_safe<PoolingBackwardImpl>( | |||||
| {this, src, dst, diff, grad}); | |||||
| } | |||||
| PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| const TensorLayout& diff, const TensorLayout& grad, | const TensorLayout& diff, const TensorLayout& grad, | ||||
| @@ -55,6 +55,8 @@ public: | |||||
| protected: | protected: | ||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
| @@ -99,6 +101,9 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| const TensorLayout& diff, const TensorLayout& grad) override; | const TensorLayout& diff, const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| const TensorLayout& diff, const TensorLayout& grad, | const TensorLayout& diff, const TensorLayout& grad, | ||||
| @@ -26,6 +26,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | AlgoBase::SizeArgs args{this, A, B, C}; | ||||
| return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | ||||
| } | } | ||||
| std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||||
| BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_all_algorithms_safe<BatchedMatrixMulForwardImpl>(args); | |||||
| } | |||||
| BatchedMatrixMulForwardImpl::Algorithm* | BatchedMatrixMulForwardImpl::Algorithm* | ||||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| @@ -35,6 +35,9 @@ private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| @@ -279,11 +279,18 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms( | |||||
| auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | ||||
| auto ret = get_all_algorithms_with_ncb(fparam); | auto ret = get_all_algorithms_with_ncb(fparam); | ||||
| if (ret.empty()) { | if (ret.empty()) { | ||||
| return naive::ConvBiasForwardImpl::get_all_algorithms(src, filter, bias, | |||||
| return naive::ConvBiasForwardImpl::get_all_algorithms_safe(src, filter, bias, | |||||
| z, dst); | z, dst); | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst) { | |||||
| auto ret_safe = ConvBiasImpl::get_all_algorithms(src,filter,bias,z,dst); | |||||
| return ret_safe; | |||||
| } | |||||
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( | ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| @@ -87,6 +87,10 @@ public: | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst) override; | |||||
| //! implemented by get_algorithm_heuristic_with_ncb() | //! implemented by get_algorithm_heuristic_with_ncb() | ||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| @@ -198,12 +198,19 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms( | |||||
| auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); | auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); | ||||
| auto ret = get_all_algorithms_with_ncb(fparam); | auto ret = get_all_algorithms_with_ncb(fparam); | ||||
| if (ret.empty()) { | if (ret.empty()) { | ||||
| return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter, | |||||
| return naive::ConvolutionForwardImpl::get_all_algorithms_safe(src, filter, | |||||
| dst); | dst); | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) { | |||||
| auto ret_safe = ConvolutionImpl::get_all_algorithms(src,filter,dst); | |||||
| return ret_safe; | |||||
| } | |||||
| ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( | ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| @@ -536,10 +543,19 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
| } | } | ||||
| auto fparam = make_ncb_kern_size_param(filter, diff, grad); | auto fparam = make_ncb_kern_size_param(filter, diff, grad); | ||||
| auto ret = get_all_algorithms_with_ncb(fparam); | auto ret = get_all_algorithms_with_ncb(fparam); | ||||
| megdnn_assert(!ret.empty(), "no usable conv fwd algorithm"); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||||
| ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| auto ret_safe = ConvolutionBackwardDataImpl::get_all_algorithms(filter,diff,grad); | |||||
| megdnn_assert(!ret_safe.empty(), "no usable conv bwd algorithm"); | |||||
| return ret_safe; | |||||
| } | |||||
| ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| @@ -85,6 +85,10 @@ public: | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| //! implemented by get_algorithm_heuristic_with_ncb() | //! implemented by get_algorithm_heuristic_with_ncb() | ||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| @@ -326,6 +330,9 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -96,6 +96,13 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms( | |||||
| return gemv_algos; | return gemv_algos; | ||||
| } | } | ||||
| std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms_safe( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||||
| auto gemv_algos_safe = get_all_algorithms(A,B,C); | |||||
| megdnn_assert(!gemv_algos_safe.empty(), "no usable MatrixMul fwd algorithm"); | |||||
| return gemv_algos_safe; | |||||
| } | |||||
| MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( | MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( | ||||
| const AlgorithmDesc& desc) { | const AlgorithmDesc& desc) { | ||||
| if (!desc.valid()) { | if (!desc.valid()) { | ||||
| @@ -270,6 +270,10 @@ protected: | |||||
| const TensorLayout& B, | const TensorLayout& B, | ||||
| const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
| @@ -128,6 +128,16 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout&, | |||||
| ->default_batch_conv_bias_fwd_algo()}; | ->default_batch_conv_bias_fwd_algo()}; | ||||
| } | } | ||||
| std::vector<BatchConvBiasForward::Algorithm*> | |||||
| BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) { | |||||
| return {static_cast<HandleImpl*>(handle()) | |||||
| ->default_batch_conv_bias_fwd_algo()}; | |||||
| } | |||||
| BatchConvBiasForward::Algorithm* | BatchConvBiasForward::Algorithm* | ||||
| BatchConvBiasForwardImpl::get_algorithm_heuristic( | BatchConvBiasForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
| @@ -30,6 +30,11 @@ public: | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| @@ -63,7 +63,6 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, | |||||
| } | } | ||||
| } | } | ||||
| std::vector<BatchedMatrixMulForward::Algorithm*> | std::vector<BatchedMatrixMulForward::Algorithm*> | ||||
| BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | ||||
| const TensorLayout& /*B*/, | const TensorLayout& /*B*/, | ||||
| @@ -71,6 +70,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||||
| return {static_cast<HandleImpl*>(handle()) | return {static_cast<HandleImpl*>(handle()) | ||||
| ->default_batched_matmul_fwd_algo()}; | ->default_batched_matmul_fwd_algo()}; | ||||
| } | } | ||||
| std::vector<BatchedMatrixMulForward::Algorithm*> | |||||
| BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) { | |||||
| return {static_cast<HandleImpl*>(handle()) | |||||
| ->default_batched_matmul_fwd_algo()}; | |||||
| } | |||||
| BatchedMatrixMulForward::Algorithm* | BatchedMatrixMulForward::Algorithm* | ||||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| @@ -27,6 +27,9 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| @@ -321,6 +321,15 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout&, | |||||
| return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()}; | ||||
| } | } | ||||
| std::vector<ConvBiasForward::Algorithm*> | |||||
| ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) { | |||||
| return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()}; | |||||
| } | |||||
| ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
| const TensorLayout& /* bias */, const TensorLayout& /* z */, | const TensorLayout& /* bias */, const TensorLayout& /* z */, | ||||
| @@ -31,6 +31,11 @@ public: | |||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| @@ -287,6 +287,13 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &, | |||||
| return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()}; | return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()}; | ||||
| } | } | ||||
| std::vector<ConvolutionForward::Algorithm *> | |||||
| ConvolutionForwardImpl:: get_all_algorithms_safe(const TensorLayout &, | |||||
| const TensorLayout &, const TensorLayout &) | |||||
| { | |||||
| return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()}; | |||||
| } | |||||
| ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
| const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, | ||||
| @@ -313,6 +320,13 @@ ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | |||||
| return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()}; | return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()}; | ||||
| } | } | ||||
| std::vector<ConvolutionBackwardData::Algorithm *> | |||||
| ConvolutionBackwardDataImpl:: get_all_algorithms_safe(const TensorLayout &, | |||||
| const TensorLayout &, const TensorLayout &) | |||||
| { | |||||
| return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()}; | |||||
| } | |||||
| ConvolutionBackwardData::Algorithm* | ConvolutionBackwardData::Algorithm* | ||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* filter */, const TensorLayout& /* diff */, | const TensorLayout& /* filter */, const TensorLayout& /* diff */, | ||||
| @@ -341,6 +355,13 @@ ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | |||||
| return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()}; | return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()}; | ||||
| } | } | ||||
| std::vector<ConvolutionBackwardFilter::Algorithm *> | |||||
| ConvolutionBackwardFilterImpl:: get_all_algorithms_safe(const TensorLayout &, | |||||
| const TensorLayout &, const TensorLayout &) | |||||
| { | |||||
| return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()}; | |||||
| } | |||||
| ConvolutionBackwardFilter::Algorithm* | ConvolutionBackwardFilter::Algorithm* | ||||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
| @@ -25,6 +25,9 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | ||||
| const TensorLayout &filter, | const TensorLayout &filter, | ||||
| const TensorLayout &dst) override; | const TensorLayout &dst) override; | ||||
| std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &src, | |||||
| const TensorLayout &filter, | |||||
| const TensorLayout &dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| @@ -67,6 +70,9 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | ||||
| const TensorLayout &diff, | const TensorLayout &diff, | ||||
| const TensorLayout &grad) override; | const TensorLayout &grad) override; | ||||
| std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &filter, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -90,6 +96,9 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | ||||
| const TensorLayout &diff, | const TensorLayout &diff, | ||||
| const TensorLayout &grad) override; | const TensorLayout &grad) override; | ||||
| std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &src, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -108,13 +108,18 @@ void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, | |||||
| megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
| } | } | ||||
| std::vector<Convolution3DForward::Algorithm*> | std::vector<Convolution3DForward::Algorithm*> | ||||
| Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, | Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, | ||||
| const TensorLayout&, | const TensorLayout&, | ||||
| const TensorLayout&) { | const TensorLayout&) { | ||||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | ||||
| } | } | ||||
| std::vector<Convolution3DForward::Algorithm*> | |||||
| Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) { | |||||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | |||||
| } | |||||
| Convolution3DForward::Algorithm* | Convolution3DForward::Algorithm* | ||||
| Convolution3DForwardImpl::get_algorithm_heuristic( | Convolution3DForwardImpl::get_algorithm_heuristic( | ||||
| @@ -143,6 +148,13 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | ||||
| } | } | ||||
| std::vector<Convolution3DBackwardData::Algorithm*> | |||||
| Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) { | |||||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | |||||
| } | |||||
| Convolution3DBackwardData::Algorithm* | Convolution3DBackwardData::Algorithm* | ||||
| Convolution3DBackwardDataImpl::get_algorithm_heuristic( | Convolution3DBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* filter */, const TensorLayout& /* diff */, | const TensorLayout& /* filter */, const TensorLayout& /* diff */, | ||||
| @@ -172,6 +184,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||||
| ->default_conv3d_bwd_filter_algo()}; | ->default_conv3d_bwd_filter_algo()}; | ||||
| } | } | ||||
| std::vector<Convolution3DBackwardFilter::Algorithm*> | |||||
| Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) { | |||||
| return {static_cast<HandleImpl*>(handle()) | |||||
| ->default_conv3d_bwd_filter_algo()}; | |||||
| } | |||||
| Convolution3DBackwardFilter::Algorithm* | Convolution3DBackwardFilter::Algorithm* | ||||
| Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
| @@ -22,6 +22,9 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| @@ -44,6 +47,9 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -66,6 +72,9 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -25,6 +25,12 @@ public: | |||||
| const TensorLayout& /* dst */) override { | const TensorLayout& /* dst */) override { | ||||
| return std::vector<Algorithm*>(); | return std::vector<Algorithm*>(); | ||||
| }; | }; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /* im */, const TensorLayout& /* filter */, | |||||
| const TensorLayout& /* offset */, const TensorLayout& /* mask */, | |||||
| const TensorLayout& /* dst */) override { | |||||
| return std::vector<Algorithm*>(); | |||||
| }; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
| @@ -67,6 +73,13 @@ public: | |||||
| const TensorLayout& /* filter_grad */) override { | const TensorLayout& /* filter_grad */) override { | ||||
| return std::vector<Algorithm*>(); | return std::vector<Algorithm*>(); | ||||
| }; | }; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /* im */, const TensorLayout& /* offset */, | |||||
| const TensorLayout& /* mask */, const TensorLayout& /* out_grad */, | |||||
| const TensorLayout& /* filter_grad */) override { | |||||
| return std::vector<Algorithm*>(); | |||||
| }; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /* im */, const TensorLayout& /* offset */, | const TensorLayout& /* im */, const TensorLayout& /* offset */, | ||||
| @@ -112,6 +125,16 @@ public: | |||||
| return std::vector<Algorithm*>(); | return std::vector<Algorithm*>(); | ||||
| }; | }; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /* im */, const TensorLayout& /* filter */, | |||||
| const TensorLayout& /* offset */, const TensorLayout& /* mask */, | |||||
| const TensorLayout& /* out_grad */, | |||||
| const TensorLayout& /* im_grad */, | |||||
| const TensorLayout& /* offset_grad */, | |||||
| const TensorLayout& /* mask_grad */) override { | |||||
| return std::vector<Algorithm*>(); | |||||
| }; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /* im */, const TensorLayout& /* filter */, | const TensorLayout& /* im */, const TensorLayout& /* filter */, | ||||
| const TensorLayout& /* offset */, const TensorLayout& /* mask */, | const TensorLayout& /* offset */, const TensorLayout& /* mask */, | ||||
| @@ -159,6 +159,13 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&, | |||||
| return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; | ||||
| } | } | ||||
| std::vector<LocalShareForward::Algorithm*> | |||||
| LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) { | |||||
| return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; | |||||
| } | |||||
| LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | ||||
| @@ -187,6 +194,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||||
| ->default_local_share_bwd_data_algo()}; | ->default_local_share_bwd_data_algo()}; | ||||
| } | } | ||||
| std::vector<LocalShareBackwardData::Algorithm*> | |||||
| LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) { | |||||
| return {static_cast<HandleImpl*>(handle()) | |||||
| ->default_local_share_bwd_data_algo()}; | |||||
| } | |||||
| LocalShareBackwardData::Algorithm* | LocalShareBackwardData::Algorithm* | ||||
| LocalShareBackwardDataImpl::get_algorithm_heuristic( | LocalShareBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* filter */, const TensorLayout& /* diff */, | const TensorLayout& /* filter */, const TensorLayout& /* diff */, | ||||
| @@ -216,6 +231,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||||
| ->default_local_share_bwd_filter_algo()}; | ->default_local_share_bwd_filter_algo()}; | ||||
| } | } | ||||
| std::vector<LocalShareBackwardFilter::Algorithm*> | |||||
| LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) { | |||||
| return {static_cast<HandleImpl*>(handle()) | |||||
| ->default_local_share_bwd_filter_algo()}; | |||||
| } | |||||
| LocalShareBackwardFilter::Algorithm* | LocalShareBackwardFilter::Algorithm* | ||||
| LocalShareBackwardFilterImpl::get_algorithm_heuristic( | LocalShareBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
| @@ -30,6 +30,10 @@ public: | |||||
| const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | ||||
| const TensorLayout& /*dst*/) override; | const TensorLayout& /*dst*/) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | |||||
| const TensorLayout& /*dst*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | ||||
| const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/, | ||||
| @@ -55,6 +59,10 @@ public: | |||||
| const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | ||||
| const TensorLayout& /*grad*/) override; | const TensorLayout& /*grad*/) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | |||||
| const TensorLayout& /*grad*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | ||||
| const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | ||||
| @@ -75,11 +83,14 @@ public: | |||||
| const TensorLayout&) override { | const TensorLayout&) override { | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | ||||
| const TensorLayout& /*grad*/) override; | const TensorLayout& /*grad*/) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | |||||
| const TensorLayout& /*grad*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | ||||
| const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | ||||
| @@ -88,6 +88,13 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||||
| return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | ||||
| } | } | ||||
| std::vector<MatrixMulForward::Algorithm*> | |||||
| MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) { | |||||
| return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | |||||
| } | |||||
| MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | ||||
| @@ -29,6 +29,10 @@ public: | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | ||||
| @@ -603,6 +603,10 @@ std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms( | |||||
| const TensorLayout&, const TensorLayout&) { | const TensorLayout&, const TensorLayout&) { | ||||
| return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()}; | ||||
| } | } | ||||
| std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms_safe( | |||||
| const TensorLayout&, const TensorLayout&) { | |||||
| return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()}; | |||||
| } | |||||
| Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | ||||
| @@ -626,6 +630,11 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms( | |||||
| const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { | const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { | ||||
| return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()}; | return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()}; | ||||
| } | } | ||||
| std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe( | |||||
| const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | |||||
| const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { | |||||
| return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()}; | |||||
| } | |||||
| Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | ||||
| @@ -35,6 +35,8 @@ class PoolingForwardImpl: public PoolingForward { | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| @@ -60,6 +62,9 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| const TensorLayout& diff, const TensorLayout& grad) override; | const TensorLayout& diff, const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| @@ -29,6 +29,14 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | ||||
| } | } | ||||
| std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||||
| BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_all_algorithms_safe<BatchedMatrixMulForwardImpl>(args); | |||||
| } | |||||
| BatchedMatrixMulForwardImpl::Algorithm* | BatchedMatrixMulForwardImpl::Algorithm* | ||||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| @@ -35,6 +35,9 @@ private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| @@ -109,6 +109,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | |||||
| {this, src, filter, dst}); | {this, src, filter, dst}); | ||||
| } | } | ||||
| std::vector<ConvolutionForwardImpl::Algorithm*> | |||||
| ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst) { | |||||
| return megdnn::get_all_algorithms_safe<ConvolutionForwardImpl>( | |||||
| {this, src, filter, dst}); | |||||
| } | |||||
| size_t ConvolutionForwardImpl::get_workspace_in_bytes( | size_t ConvolutionForwardImpl::get_workspace_in_bytes( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, const PreprocessedFilter*) { | const TensorLayout& dst, const PreprocessedFilter*) { | ||||
| @@ -162,6 +170,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||||
| {this, filter, diff, grad}); | {this, filter, diff, grad}); | ||||
| } | } | ||||
| std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||||
| ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| return megdnn::get_all_algorithms_safe<ConvolutionBackwardDataImpl>( | |||||
| {this, filter, diff, grad}); | |||||
| } | |||||
| ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| @@ -243,6 +259,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||||
| {this, src, diff, grad}); | {this, src, diff, grad}); | ||||
| } | } | ||||
| std::vector<ConvolutionBackwardFilterImpl::Algorithm*> | |||||
| ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| return megdnn::get_all_algorithms_safe<ConvolutionBackwardFilterImpl>( | |||||
| {this, src, diff, grad}); | |||||
| } | |||||
| ConvolutionBackwardFilterImpl::Algorithm* | ConvolutionBackwardFilterImpl::Algorithm* | ||||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| @@ -74,6 +74,9 @@ private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| @@ -123,6 +126,9 @@ private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -172,6 +178,9 @@ private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| @@ -27,6 +27,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | ||||
| } | } | ||||
| std::vector<MatrixMulForwardImpl::Algorithm*> | |||||
| MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||||
| return megdnn::get_all_algorithms_safe<MatrixMulForwardImpl>(args); | |||||
| } | |||||
| MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
| @@ -36,6 +36,10 @@ private: | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | ||||
| @@ -25,12 +25,16 @@ size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| const char* PoolingForwardImpl::get_algorithm_set_name() const { | const char* PoolingForwardImpl::get_algorithm_set_name() const { | ||||
| return "ROCM_POOLING_FORWARD"; | return "ROCM_POOLING_FORWARD"; | ||||
| } | } | ||||
| std::vector<PoolingForwardImpl::Algorithm*> | std::vector<PoolingForwardImpl::Algorithm*> | ||||
| PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | ||||
| const TensorLayout& dst) { | const TensorLayout& dst) { | ||||
| return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | ||||
| } | } | ||||
| std::vector<PoolingForwardImpl::Algorithm*> | |||||
| PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||||
| const TensorLayout& dst) { | |||||
| return megdnn::get_all_algorithms_safe<PoolingForwardImpl>({this, src, dst}); | |||||
| } | |||||
| PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| @@ -82,6 +86,13 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms( | |||||
| {this, src, dst, diff, grad}); | {this, src, dst, diff, grad}); | ||||
| } | } | ||||
| std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& diff, const TensorLayout& grad) { | |||||
| return megdnn::get_all_algorithms_safe<PoolingBackwardImpl>( | |||||
| {this, src, dst, diff, grad}); | |||||
| } | |||||
| Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| const TensorLayout& diff, const TensorLayout& grad, | const TensorLayout& diff, const TensorLayout& grad, | ||||
| @@ -46,6 +46,8 @@ class PoolingForwardImpl final: public PoolingForward { | |||||
| protected: | protected: | ||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
| @@ -93,6 +95,9 @@ class PoolingBackwardImpl final: public PoolingBackward { | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| const TensorLayout& diff, const TensorLayout& grad) override; | const TensorLayout& diff, const TensorLayout& grad) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| const TensorLayout& diff, const TensorLayout& grad, | const TensorLayout& diff, const TensorLayout& grad, | ||||
| @@ -74,11 +74,14 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
| return fallback_worksapce; | return fallback_worksapce; | ||||
| } | } | ||||
| } | } | ||||
| std::vector<Algorithm*> PoolingImpl::get_all_algorithms( | std::vector<Algorithm*> PoolingImpl::get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& dst) { | ||||
| return megdnn::get_all_algorithms<PoolingImpl>({this, src, dst}); | return megdnn::get_all_algorithms<PoolingImpl>({this, src, dst}); | ||||
| } | } | ||||
| std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst) { | |||||
| return megdnn::get_all_algorithms_safe<PoolingImpl>({this, src, dst}); | |||||
| } | |||||
| Algorithm* PoolingImpl::get_algorithm_heuristic( | Algorithm* PoolingImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| @@ -63,6 +63,8 @@ public: | |||||
| protected: | protected: | ||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& dst) override; | const TensorLayout& src, const TensorLayout& dst) override; | ||||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||||
| const TensorLayout& src, const TensorLayout& dst) override; | |||||
| Algorithm* get_algorithm_heuristic( | Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& dst, | const TensorLayout& src, const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | ||||
| @@ -164,7 +164,7 @@ public: | |||||
| } | } | ||||
| std::vector<Algorithm::Info::Desc> ret; | std::vector<Algorithm::Info::Desc> ret; | ||||
| megdnn_assert(layouts.size() == OprTrait<Opr>::arity); | megdnn_assert(layouts.size() == OprTrait<Opr>::arity); | ||||
| auto vec = AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info( | |||||
| auto vec = AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe( | |||||
| opr, layouts); | opr, layouts); | ||||
| for (auto algo_info : vec) { | for (auto algo_info : vec) { | ||||
| if (!(algo_info.attribute & | if (!(algo_info.attribute & | ||||
| @@ -377,7 +377,7 @@ float algo_benchmark(Benchmarker<Opr, T>& benchmark, TensorLayoutArray layouts, | |||||
| auto opr = benchmark.opr(); | auto opr = benchmark.opr(); | ||||
| opr->param() = benchmark.param(); | opr->param() = benchmark.param(); | ||||
| proxy.deduce_layout(opr, layouts); | proxy.deduce_layout(opr, layouts); | ||||
| auto algos = OprAlgoProxy<Opr>::get_all_algorithms_info(opr, layouts); | |||||
| auto algos = OprAlgoProxy<Opr>::get_all_algorithms_info_safe(opr, layouts); | |||||
| float min_used = std::numeric_limits<float>::max(); | float min_used = std::numeric_limits<float>::max(); | ||||
| bool execed = false; | bool execed = false; | ||||
| for (auto i : algos) { | for (auto i : algos) { | ||||
| @@ -514,7 +514,7 @@ struct ExecutionPolicyAlgoName { | |||||
| * \brief a callable to check that given algorithm is used for heuristic | * \brief a callable to check that given algorithm is used for heuristic | ||||
| * \param require_algo if its value is true, then requires | * \param require_algo if its value is true, then requires | ||||
| * get_algorithm_heuristic() to return the expected algo; otherwise the | * get_algorithm_heuristic() to return the expected algo; otherwise the | ||||
| * expected algo must exist in get_all_algorithms() and it would be set to | |||||
| * expected algo must exist in get_all_algorithms_safe() and it would be set to | |||||
| * be used | * be used | ||||
| */ | */ | ||||
| template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>> | template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>> | ||||
| @@ -536,7 +536,7 @@ public: | |||||
| opr->param() = | opr->param() = | ||||
| Algorithm::deserialize_read_pod<typename Opr::Param>(param); | Algorithm::deserialize_read_pod<typename Opr::Param>(param); | ||||
| for (auto algo_info : | for (auto algo_info : | ||||
| AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info( | |||||
| AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe( | |||||
| opr.get(), layouts)) { | opr.get(), layouts)) { | ||||
| if (std::regex_match( | if (std::regex_match( | ||||
| algo_info.desc.name, | algo_info.desc.name, | ||||
| @@ -695,7 +695,7 @@ Checker<Convolution> checker(handle); | |||||
| float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW); | float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW); | ||||
| UniformFloatRNG rng(scale, 2 * scale); | UniformFloatRNG rng(scale, 2 * scale); | ||||
| checker.set_rng(0, &rng).set_rng(1, &rng); | checker.set_rng(0, &rng).set_rng(1, &rng); | ||||
| for (auto algo : opr->get_all_algorithms_info(ily, fly, oly)) { | |||||
| for (auto algo : opr->get_all_algorithms_info_safe(ily, fly, oly)) { | |||||
| used_algos.insert(algo.desc); | used_algos.insert(algo.desc); | ||||
| opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
| @@ -720,7 +720,7 @@ Checker<Convolution> checker(handle); | |||||
| opr->param() = param; | opr->param() = param; | ||||
| std::string param_str; | std::string param_str; | ||||
| Algorithm::serialize_write_pod(opr->param(), param_str); | Algorithm::serialize_write_pod(opr->param(), param_str); | ||||
| for (auto algo : opr->get_all_algorithms_info(fly, oly, ily)) { | |||||
| for (auto algo : opr->get_all_algorithms_info_safe(fly, oly, ily)) { | |||||
| used_algos_bwd_data.insert(algo.desc); | used_algos_bwd_data.insert(algo.desc); | ||||
| opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
| construct_sub_execution_policy_heuristic< | construct_sub_execution_policy_heuristic< | ||||
| @@ -747,7 +747,7 @@ Checker<Convolution> checker(handle); | |||||
| opr->param() = param; | opr->param() = param; | ||||
| std::string param_str; | std::string param_str; | ||||
| Algorithm::serialize_write_pod(opr->param(), param_str); | Algorithm::serialize_write_pod(opr->param(), param_str); | ||||
| for (auto algo : opr->get_all_algorithms_info(ily, oly, fly)) { | |||||
| for (auto algo : opr->get_all_algorithms_info_safe(ily, oly, fly)) { | |||||
| used_algos_bwd_flt.insert(algo.desc); | used_algos_bwd_flt.insert(algo.desc); | ||||
| opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
| construct_sub_execution_policy_heuristic< | construct_sub_execution_policy_heuristic< | ||||
| @@ -25,9 +25,9 @@ struct AlgoProxy; | |||||
| template <typename Opr> \ | template <typename Opr> \ | ||||
| struct AlgoProxy<Opr, arity> { \ | struct AlgoProxy<Opr, arity> { \ | ||||
| static std::vector<typename Opr::AlgorithmInfo> \ | static std::vector<typename Opr::AlgorithmInfo> \ | ||||
| get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \ | |||||
| get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \ | |||||
| megdnn_assert(layouts.size() == arity); \ | megdnn_assert(layouts.size() == arity); \ | ||||
| return opr->get_all_algorithms_info(LAYOUTS); \ | |||||
| return opr->get_all_algorithms_info_safe(LAYOUTS); \ | |||||
| } \ | } \ | ||||
| static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | ||||
| Opr* opr, const TensorLayoutArray& layouts) { \ | Opr* opr, const TensorLayoutArray& layouts) { \ | ||||
| @@ -80,9 +80,9 @@ DEF_ALGO_PROXY(8); | |||||
| template <> \ | template <> \ | ||||
| struct AlgoProxy<Opr, arity> { \ | struct AlgoProxy<Opr, arity> { \ | ||||
| static std::vector<typename Opr::AlgorithmInfo> \ | static std::vector<typename Opr::AlgorithmInfo> \ | ||||
| get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \ | |||||
| get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \ | |||||
| megdnn_assert(layouts.size() == arity); \ | megdnn_assert(layouts.size() == arity); \ | ||||
| return opr->get_all_algorithms_info(LAYOUTS); \ | |||||
| return opr->get_all_algorithms_info_safe(LAYOUTS); \ | |||||
| } \ | } \ | ||||
| static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | ||||
| Opr* opr, const TensorLayoutArray& layouts) { \ | Opr* opr, const TensorLayoutArray& layouts) { \ | ||||
| @@ -288,7 +288,7 @@ struct OprProxyProfilingBase | |||||
| Algorithm::deserialize_read_pod<typename Opr::Param>(param); | Algorithm::deserialize_read_pod<typename Opr::Param>(param); | ||||
| std::vector<Algorithm::SearchItem> ret; | std::vector<Algorithm::SearchItem> ret; | ||||
| for (auto algo_info : AlgoProxy<Opr, arity>::get_all_algorithms_info( | |||||
| for (auto algo_info : AlgoProxy<Opr, arity>::get_all_algorithms_info_safe( | |||||
| opr.get(), layouts)) { | opr.get(), layouts)) { | ||||
| Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc); | Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc); | ||||
| std::vector<Algorithm::SearchItem>&& sub_items = | std::vector<Algorithm::SearchItem>&& sub_items = | ||||
| @@ -367,7 +367,7 @@ struct OprProxyProfilingBase | |||||
| megdnn_log("Find best algo %s in cache", algo->name()); | megdnn_log("Find best algo %s in cache", algo->name()); | ||||
| return; | return; | ||||
| } | } | ||||
| for (auto algo : AlgoProxy<Opr, arity>::get_all_algorithms_info( | |||||
| for (auto algo : AlgoProxy<Opr, arity>::get_all_algorithms_info_safe( | |||||
| opr.get(), layouts)) { | opr.get(), layouts)) { | ||||
| //! construct execution_policy | //! construct execution_policy | ||||
| opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
| @@ -492,7 +492,7 @@ struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> { | |||||
| if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) { | if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) { | ||||
| size_t min_time = std::numeric_limits<size_t>::max(); | size_t min_time = std::numeric_limits<size_t>::max(); | ||||
| for (auto algo : | for (auto algo : | ||||
| AlgoProxy<Opr, arity>::get_all_algorithms_info(opr, layouts)) { | |||||
| AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr, layouts)) { | |||||
| opr->execution_policy().algo = algo.desc; | opr->execution_policy().algo = algo.desc; | ||||
| auto preprocess_tensors = | auto preprocess_tensors = | ||||
| @@ -84,7 +84,7 @@ void test_multibatchsize( | |||||
| auto opr_reference = handle_cuda->create_operator<MatrixMulForward>(); | auto opr_reference = handle_cuda->create_operator<MatrixMulForward>(); | ||||
| { | { | ||||
| opr_reference->execution_policy().algo.reset(); | opr_reference->execution_policy().algo.reset(); | ||||
| for (auto i : opr_reference->get_all_algorithms_info( | |||||
| for (auto i : opr_reference->get_all_algorithms_info_safe( | |||||
| A_tensor.layout(), B_tensor.layout(), | A_tensor.layout(), B_tensor.layout(), | ||||
| C_tensor.layout())) { | C_tensor.layout())) { | ||||
| if (std::regex_match( | if (std::regex_match( | ||||
| @@ -113,7 +113,7 @@ void test_multibatchsize( | |||||
| {{}, {}, C_tensor_prime.tensornd_host()}); | {{}, {}, C_tensor_prime.tensornd_host()}); | ||||
| { | { | ||||
| opr_reference->execution_policy().algo.reset(); | opr_reference->execution_policy().algo.reset(); | ||||
| for (auto i : opr_reference->get_all_algorithms_info( | |||||
| for (auto i : opr_reference->get_all_algorithms_info_safe( | |||||
| A_tensor_prime.layout(), B_tensor.layout(), | A_tensor_prime.layout(), B_tensor.layout(), | ||||
| C_tensor_batch.layout())) { | C_tensor_batch.layout())) { | ||||
| if (std::regex_match( | if (std::regex_match( | ||||
| @@ -1938,7 +1938,7 @@ typename megdnn::ExecutionPolicy try_find_any_weight_preprocess_algo( | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| } | } | ||||
| for (auto&& algo : dnn_op->get_all_algorithms_info( | |||||
| for (auto&& algo : dnn_op->get_all_algorithms_info_safe( | |||||
| std::forward<Args>(args)...)) { | std::forward<Args>(args)...)) { | ||||
| dnn_op->execution_policy().algo = algo.desc; | dnn_op->execution_policy().algo = algo.desc; | ||||
| auto layouts = dnn_op->deduce_preprocessed_filter_layout( | auto layouts = dnn_op->deduce_preprocessed_filter_layout( | ||||
| @@ -1972,7 +1972,7 @@ typename megdnn::ExecutionPolicy try_find_any_bias_preprocess_algo( | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| } | } | ||||
| for (auto&& algo : dnn_op->get_all_algorithms_info( | |||||
| for (auto&& algo : dnn_op->get_all_algorithms_info_safe( | |||||
| std::forward<Args>(args)...)) { | std::forward<Args>(args)...)) { | ||||
| dnn_op->execution_policy().algo = algo.desc; | dnn_op->execution_policy().algo = algo.desc; | ||||
| auto layouts = dnn_op->deduce_preprocessed_filter_layout( | auto layouts = dnn_op->deduce_preprocessed_filter_layout( | ||||
| @@ -805,7 +805,7 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo> | |||||
| AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { | AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { | ||||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) | MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) | ||||
| auto heu = choose_by_heuristic(m_execution_policy.strategy); | auto heu = choose_by_heuristic(m_execution_policy.strategy); | ||||
| auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...), | |||||
| auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info_safe(args...), | |||||
| m_fastrun_layouts); | m_fastrun_layouts); | ||||
| bool found = false; | bool found = false; | ||||
| for (size_t i = 0; i < ret.size(); ++i) { | for (size_t i = 0; i < ret.size(); ++i) { | ||||
| @@ -2473,6 +2473,11 @@ public: | |||||
| std::vector<AlgorithmInfo>(const TensorLayout& p0, | std::vector<AlgorithmInfo>(const TensorLayout& p0, | ||||
| const TensorLayout& p1, | const TensorLayout& p1, | ||||
| const TensorLayout& p2)); | const TensorLayout& p2)); | ||||
| MOCK_METHOD3(get_all_algorithms_info_safe, | |||||
| std::vector<AlgorithmInfo>(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2)); | |||||
| MOCK_METHOD6(get_algorithm_info_heuristic, | MOCK_METHOD6(get_algorithm_info_heuristic, | ||||
| AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, | AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, | ||||
| const TensorLayout& p2, | const TensorLayout& p2, | ||||
| @@ -2484,6 +2489,11 @@ public: | |||||
| std::vector<Algorithm*>(const TensorLayout& p0, | std::vector<Algorithm*>(const TensorLayout& p0, | ||||
| const TensorLayout& p1, | const TensorLayout& p1, | ||||
| const TensorLayout& p2)); | const TensorLayout& p2)); | ||||
| MOCK_METHOD3(get_all_algorithms_safe, | |||||
| std::vector<Algorithm*>(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2)); | |||||
| MOCK_METHOD6(get_algorithm_heuristic, | MOCK_METHOD6(get_algorithm_heuristic, | ||||
| Algorithm*(const TensorLayout& p0, const TensorLayout& p1, | Algorithm*(const TensorLayout& p0, const TensorLayout& p1, | ||||
| const TensorLayout& p2, | const TensorLayout& p2, | ||||