GitOrigin-RevId: e3734e4531
tags/v1.6.0
| @@ -315,7 +315,7 @@ public: | |||
| /*! | |||
| * \brief get a string representation for current algorithm set; | |||
| * | |||
| * get_all_algorithms() may return different algorithms only if | |||
| * get_all_algorithms_safe() may return different algorithms only if | |||
| * algorithm set name differs. This is used for checking cache | |||
| * validity. | |||
| */ | |||
| @@ -354,6 +354,15 @@ public: | |||
| return ret; | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
| const TensorLayout& p1) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1)) { | |||
| ret.emplace_back(algo->info()); | |||
| } | |||
| return ret; | |||
| } | |||
| /** | |||
| * \brief Returns the best algorithm information which indicate the | |||
| * algorithm by heuristic. | |||
| @@ -378,6 +387,8 @@ protected: | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1) = 0; | |||
| /** | |||
| * \brief Returns the best algorithm by heuristic. | |||
| @@ -412,6 +423,16 @@ public: | |||
| return ret; | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { | |||
| ret.emplace_back(algo->info()); | |||
| } | |||
| return ret; | |||
| } | |||
| /** | |||
| * \brief Returns the best algorithm information which indicate the | |||
| * algorithm by heuristic. | |||
| @@ -438,6 +459,9 @@ protected: | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2) = 0; | |||
| /** | |||
| * \brief Returns the best algorithm by heuristic. | |||
| @@ -463,7 +487,7 @@ public: | |||
| using AlgoAttribute = detail::Algorithm::Attribute; | |||
| //! get all possible algorithm decriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| const TensorLayout& p3) { | |||
| @@ -474,6 +498,17 @@ public: | |||
| return ret; | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| const TensorLayout& p3) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { | |||
| ret.emplace_back(algo->info()); | |||
| } | |||
| return ret; | |||
| } | |||
| /** | |||
| * \brief Returns the best algorithm information which indicate the | |||
| * algorithm by heuristic. | |||
| @@ -500,6 +535,9 @@ protected: | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3) = 0; | |||
| /** | |||
| * \brief Returns the best algorithm by heuristic. | |||
| @@ -537,6 +575,18 @@ public: | |||
| return ret; | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| const TensorLayout& p3, | |||
| const TensorLayout& p4) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { | |||
| ret.emplace_back(algo->info()); | |||
| } | |||
| return ret; | |||
| } | |||
| /** | |||
| * \brief Returns the best algorithm information which indicate the | |||
| * algorithm by heuristic. | |||
| @@ -562,7 +612,11 @@ protected: | |||
| ~MultiAlgoOpr() = default; | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4) = 0; | |||
| @@ -604,6 +658,18 @@ public: | |||
| return ret; | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||
| ret.emplace_back(algo->info()); | |||
| } | |||
| return ret; | |||
| } | |||
| /** | |||
| * \brief Returns the best algorithm information which indicate the | |||
| * algorithm by heuristic. | |||
| @@ -629,7 +695,12 @@ protected: | |||
| ~MultiAlgoOpr() = default; | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| @@ -172,9 +172,14 @@ std::vector<Algorithm*> PoolingImpl::get_all_algorithms( | |||
| ret.push_back(i); | |||
| } | |||
| } | |||
| megdnn_assert(!ret.empty(), "no usable pooling fwd algorithm"); | |||
| return ret; | |||
| } | |||
| std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst) { | |||
| auto ret_safe = get_all_algorithms(src,dst); | |||
| megdnn_assert(!ret_safe.empty(), "no usable pooling fwd algorithm"); | |||
| return ret_safe; | |||
| } | |||
| Algorithm* PoolingImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| @@ -131,6 +131,8 @@ public: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| @@ -100,10 +100,16 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms( | |||
| ret.push_back(i); | |||
| } | |||
| } | |||
| megdnn_assert(!ret.empty(), "no algorithm for %s", | |||
| args.to_string().c_str()); | |||
| return ret; | |||
| } | |||
| template <class Opr> | |||
| std::vector<typename Opr::Algorithm*> get_all_algorithms_safe( | |||
| const typename Opr::AlgoBase::SizeArgs& args) { | |||
| auto ret_safe = get_all_algorithms<Opr>(args); | |||
| megdnn_assert(!ret_safe.empty(), "no algorithm for %s", | |||
| args.to_string().c_str()); | |||
| return ret_safe; | |||
| } | |||
| /*! | |||
| * \brief a helper function to get an algorithm match attribute. If require a | |||
| @@ -51,6 +51,15 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | |||
| return megdnn::get_all_algorithms<BatchConvBiasForwardImpl>(args); | |||
| } | |||
| std::vector<BatchConvBiasForwardImpl::Algorithm*> | |||
| BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& bias, | |||
| const TensorLayout& z, | |||
| const TensorLayout& dst) { | |||
| AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | |||
| return megdnn::get_all_algorithms_safe<BatchConvBiasForwardImpl>(args); | |||
| } | |||
| size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| @@ -42,6 +42,10 @@ protected: | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| @@ -51,6 +51,12 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms( | |||
| } | |||
| return ret; | |||
| } | |||
| std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms_safe( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||
| auto ret_safe = get_all_algorithms(A,B,C); | |||
| megdnn_assert(!ret_safe.empty(), "no usable batchedmatrixmulForward fwd algorithm"); | |||
| return ret_safe; | |||
| } | |||
| Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| @@ -45,6 +45,9 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||
| @@ -49,6 +49,16 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| {this, src, filter, bias, z, dst}); | |||
| } | |||
| std::vector<ConvBiasForward::Algorithm*> | |||
| ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& bias, | |||
| const TensorLayout& z, | |||
| const TensorLayout& dst) { | |||
| return megdnn::get_all_algorithms_safe<ConvBiasForwardImpl>( | |||
| {this, src, filter, bias, z, dst}); | |||
| } | |||
| ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| @@ -84,6 +84,10 @@ public: | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| @@ -53,6 +53,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| return megdnn::get_all_algorithms<ConvolutionForwardImpl>(args); | |||
| } | |||
| std::vector<ConvolutionForwardImpl::Algorithm*> | |||
| ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst) { | |||
| AlgoBase::SizeArgs args{this, src, filter, dst}; | |||
| return megdnn::get_all_algorithms_safe<ConvolutionForwardImpl>(args); | |||
| } | |||
| size_t ConvolutionForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst, | |||
| @@ -97,6 +105,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||
| {this, filter, diff, grad}); | |||
| } | |||
| std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||
| ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| return megdnn::get_all_algorithms_safe<ConvolutionBackwardDataImpl>( | |||
| {this, filter, diff, grad}); | |||
| } | |||
| ConvolutionBackwardDataImpl::Algorithm* | |||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| @@ -222,6 +238,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||
| {this, src, diff, grad}); | |||
| } | |||
| std::vector<ConvolutionBackwardFilterImpl::Algorithm*> | |||
| ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| return megdnn::get_all_algorithms_safe<ConvolutionBackwardFilterImpl>( | |||
| {this, src, diff, grad}); | |||
| } | |||
| ConvolutionBackwardFilterImpl::Algorithm* | |||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| @@ -59,6 +59,10 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||
| @@ -111,6 +115,10 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -159,6 +167,10 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -108,6 +108,14 @@ Convolution3DForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| {this, src, filter, dst}); | |||
| } | |||
| std::vector<Convolution3DForwardImpl::Algorithm*> | |||
| Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst) { | |||
| return megdnn::get_all_algorithms_safe<Convolution3DForwardImpl>( | |||
| {this, src, filter, dst}); | |||
| } | |||
| size_t Convolution3DForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) { | |||
| @@ -146,6 +154,14 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||
| {this, filter, diff, grad}); | |||
| } | |||
| std::vector<Convolution3DBackwardDataImpl::Algorithm*> | |||
| Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| return megdnn::get_all_algorithms_safe<Convolution3DBackwardDataImpl>( | |||
| {this, filter, diff, grad}); | |||
| } | |||
| Convolution3DBackwardDataImpl::Algorithm* | |||
| Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| @@ -226,6 +242,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||
| {this, src, diff, grad}); | |||
| } | |||
| std::vector<Convolution3DBackwardFilterImpl::Algorithm*> | |||
| Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| return megdnn::get_all_algorithms_safe<Convolution3DBackwardFilterImpl>( | |||
| {this, src, diff, grad}); | |||
| } | |||
| Convolution3DBackwardFilterImpl::Algorithm* | |||
| Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| @@ -39,6 +39,9 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||
| @@ -72,6 +75,9 @@ public: | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| @@ -109,6 +115,9 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -51,6 +51,15 @@ std::vector<AlgoFwd*> Fwd::get_all_algorithms(const TensorLayout& /* im */, | |||
| return algos; | |||
| } | |||
| std::vector<AlgoFwd*> Fwd::get_all_algorithms_safe(const TensorLayout& im, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& offset, | |||
| const TensorLayout& mask, | |||
| const TensorLayout& dst) { | |||
| auto ret_safe = Fwd::get_all_algorithms(im,filter,offset,mask,dst); | |||
| megdnn_assert(!ret_safe.empty(), "no usable deformable_conv fwd algorithm"); | |||
| return ret_safe; | |||
| } | |||
| AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | |||
| const TensorLayout& filter, | |||
| @@ -115,6 +124,14 @@ std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms(const TensorLayout& /* im */ | |||
| return algos; | |||
| } | |||
| std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms_safe(const TensorLayout& im, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& out_grad, const TensorLayout& filter_grad) { | |||
| auto ret_safe = BwdFlt::get_all_algorithms(im,offset,mask,out_grad,filter_grad); | |||
| megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd filter algorithm"); | |||
| return ret_safe; | |||
| } | |||
| AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | |||
| const TensorLayout& im, const TensorLayout& offset, | |||
| const TensorLayout& mask, const TensorLayout& out_grad, | |||
| @@ -181,6 +198,14 @@ std::vector<AlgoBwdData*> BwdData::get_all_algorithms( | |||
| algos.push_back(static_cast<AlgoBwdData*>(i)); | |||
| return algos; | |||
| } | |||
| std::vector<AlgoBwdData*> BwdData::get_all_algorithms_safe( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, | |||
| const TensorLayout& im_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad ) { | |||
| auto ret_safe = BwdData::get_all_algorithms(im,filter,offset,mask,out_grad,im_grad,offset_grad,mask_grad); | |||
| megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd data algorithm"); | |||
| return ret_safe; | |||
| } | |||
| AlgoBwdData* BwdData::get_algorithm_heuristic( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| @@ -54,6 +54,10 @@ protected: | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| @@ -105,6 +109,10 @@ protected: | |||
| const TensorLayout& im, const TensorLayout& offset, | |||
| const TensorLayout& mask, const TensorLayout& out_grad, | |||
| const TensorLayout& filter_grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& im, const TensorLayout& offset, | |||
| const TensorLayout& mask, const TensorLayout& out_grad, | |||
| const TensorLayout& filter_grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& im, const TensorLayout& offset, | |||
| @@ -161,6 +169,13 @@ protected: | |||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
| const TensorLayout& offset_grad, | |||
| const TensorLayout& mask_grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| const TensorLayout& offset, const TensorLayout& mask, | |||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | |||
| const TensorLayout& offset_grad, | |||
| const TensorLayout& mask_grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& im, const TensorLayout& filter, | |||
| @@ -47,7 +47,6 @@ LocalShareForwardImpl::get_algorithm_heuristic( | |||
| Algorithm::attribute_str(positive_attr).c_str(), | |||
| args.to_string().c_str(), workspace_limit_in_bytes)); | |||
| } | |||
| std::vector<LocalShareForwardImpl::Algorithm*> | |||
| LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| @@ -56,6 +55,14 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| return megdnn::get_all_algorithms<LocalShareForwardImpl>(args); | |||
| } | |||
| std::vector<LocalShareForwardImpl::Algorithm*> | |||
| LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst) { | |||
| AlgoBase::SizeArgs args{this, src, filter, dst}; | |||
| return megdnn::get_all_algorithms_safe<LocalShareForwardImpl>(args); | |||
| } | |||
| size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst) { | |||
| @@ -109,6 +116,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||
| return megdnn::get_all_algorithms<LocalShareBackwardDataImpl>(args); | |||
| } | |||
| std::vector<LocalShareBackwardDataImpl::Algorithm*> | |||
| LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| AlgoBase::SizeArgs args{this, filter, diff, grad}; | |||
| return megdnn::get_all_algorithms_safe<LocalShareBackwardDataImpl>(args); | |||
| } | |||
| size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| @@ -162,6 +177,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||
| return megdnn::get_all_algorithms<LocalShareBackwardFilterImpl>(args); | |||
| } | |||
| std::vector<LocalShareBackwardFilterImpl::Algorithm*> | |||
| LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| AlgoBase::SizeArgs args{this, src, diff, grad}; | |||
| return megdnn::get_all_algorithms_safe<LocalShareBackwardFilterImpl>(args); | |||
| } | |||
| size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| @@ -37,6 +37,9 @@ public: | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| @@ -72,6 +75,9 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -105,6 +111,9 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -28,6 +28,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||
| return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | |||
| } | |||
| std::vector<MatrixMulForwardImpl::Algorithm*> | |||
| MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_all_algorithms_safe<MatrixMulForwardImpl>(args); | |||
| } | |||
| MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||
| @@ -60,6 +60,10 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||
| @@ -33,6 +33,11 @@ PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| const TensorLayout& dst) { | |||
| return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | |||
| } | |||
| std::vector<PoolingForwardImpl::Algorithm*> | |||
| PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& dst) { | |||
| return megdnn::get_all_algorithms_safe<PoolingForwardImpl>({this, src, dst}); | |||
| } | |||
| PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| @@ -77,6 +82,15 @@ PoolingBackwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| {this, src, dst, diff, grad}); | |||
| } | |||
| std::vector<PoolingBackwardImpl::Algorithm*> | |||
| PoolingBackwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& dst, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| return megdnn::get_all_algorithms_safe<PoolingBackwardImpl>( | |||
| {this, src, dst, diff, grad}); | |||
| } | |||
| PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| @@ -55,6 +55,8 @@ public: | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||
| @@ -99,6 +101,9 @@ protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| @@ -26,6 +26,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | |||
| } | |||
| std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||
| BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_all_algorithms_safe<BatchedMatrixMulForwardImpl>(args); | |||
| } | |||
| BatchedMatrixMulForwardImpl::Algorithm* | |||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
| @@ -35,6 +35,9 @@ private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| @@ -279,11 +279,18 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms( | |||
| auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | |||
| auto ret = get_all_algorithms_with_ncb(fparam); | |||
| if (ret.empty()) { | |||
| return naive::ConvBiasForwardImpl::get_all_algorithms(src, filter, bias, | |||
| return naive::ConvBiasForwardImpl::get_all_algorithms_safe(src, filter, bias, | |||
| z, dst); | |||
| } | |||
| return ret; | |||
| } | |||
| std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) { | |||
| auto ret_safe = ConvBiasImpl::get_all_algorithms(src,filter,bias,z,dst); | |||
| return ret_safe; | |||
| } | |||
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| @@ -87,6 +87,10 @@ public: | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| //! implemented by get_algorithm_heuristic_with_ncb() | |||
| Algorithm* get_algorithm_heuristic( | |||
| @@ -198,12 +198,19 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms( | |||
| auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); | |||
| auto ret = get_all_algorithms_with_ncb(fparam); | |||
| if (ret.empty()) { | |||
| return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter, | |||
| return naive::ConvolutionForwardImpl::get_all_algorithms_safe(src, filter, | |||
| dst); | |||
| } | |||
| return ret; | |||
| } | |||
| std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) { | |||
| auto ret_safe = ConvolutionImpl::get_all_algorithms(src,filter,dst); | |||
| return ret_safe; | |||
| } | |||
| ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||
| @@ -536,10 +543,19 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||
| } | |||
| auto fparam = make_ncb_kern_size_param(filter, diff, grad); | |||
| auto ret = get_all_algorithms_with_ncb(fparam); | |||
| megdnn_assert(!ret.empty(), "no usable conv fwd algorithm"); | |||
| return ret; | |||
| } | |||
| std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||
| ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| auto ret_safe = ConvolutionBackwardDataImpl::get_all_algorithms(filter,diff,grad); | |||
| megdnn_assert(!ret_safe.empty(), "no usable conv bwd algorithm"); | |||
| return ret_safe; | |||
| } | |||
| ConvolutionBackwardDataImpl::Algorithm* | |||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| @@ -85,6 +85,10 @@ public: | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| //! implemented by get_algorithm_heuristic_with_ncb() | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| @@ -326,6 +330,9 @@ public: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -96,6 +96,13 @@ std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms( | |||
| return gemv_algos; | |||
| } | |||
| std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms_safe( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||
| auto gemv_algos_safe = get_all_algorithms(A,B,C); | |||
| megdnn_assert(!gemv_algos_safe.empty(), "no usable MatrixMul fwd algorithm"); | |||
| return gemv_algos_safe; | |||
| } | |||
| MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| if (!desc.valid()) { | |||
| @@ -270,6 +270,10 @@ protected: | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||
| @@ -128,6 +128,16 @@ BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout&, | |||
| ->default_batch_conv_bias_fwd_algo()}; | |||
| } | |||
| std::vector<BatchConvBiasForward::Algorithm*> | |||
| BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle()) | |||
| ->default_batch_conv_bias_fwd_algo()}; | |||
| } | |||
| BatchConvBiasForward::Algorithm* | |||
| BatchConvBiasForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | |||
| @@ -30,6 +30,11 @@ public: | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| @@ -63,7 +63,6 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, | |||
| } | |||
| } | |||
| std::vector<BatchedMatrixMulForward::Algorithm*> | |||
| BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| @@ -71,6 +70,13 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||
| return {static_cast<HandleImpl*>(handle()) | |||
| ->default_batched_matmul_fwd_algo()}; | |||
| } | |||
| std::vector<BatchedMatrixMulForward::Algorithm*> | |||
| BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) { | |||
| return {static_cast<HandleImpl*>(handle()) | |||
| ->default_batched_matmul_fwd_algo()}; | |||
| } | |||
| BatchedMatrixMulForward::Algorithm* | |||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
| @@ -27,6 +27,9 @@ public: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| @@ -321,6 +321,15 @@ ConvBiasForwardImpl::get_all_algorithms(const TensorLayout&, | |||
| return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()}; | |||
| } | |||
| std::vector<ConvBiasForward::Algorithm*> | |||
| ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()}; | |||
| } | |||
| ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | |||
| const TensorLayout& /* bias */, const TensorLayout& /* z */, | |||
| @@ -31,6 +31,11 @@ public: | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& bias, const TensorLayout& z, | |||
| @@ -287,6 +287,13 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &, | |||
| return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()}; | |||
| } | |||
| std::vector<ConvolutionForward::Algorithm *> | |||
| ConvolutionForwardImpl:: get_all_algorithms_safe(const TensorLayout &, | |||
| const TensorLayout &, const TensorLayout &) | |||
| { | |||
| return {static_cast<HandleImpl *>(handle())->default_conv_fwd_algo()}; | |||
| } | |||
| ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | |||
| const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, | |||
| @@ -313,6 +320,13 @@ ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &, | |||
| return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()}; | |||
| } | |||
| std::vector<ConvolutionBackwardData::Algorithm *> | |||
| ConvolutionBackwardDataImpl:: get_all_algorithms_safe(const TensorLayout &, | |||
| const TensorLayout &, const TensorLayout &) | |||
| { | |||
| return {static_cast<HandleImpl *>(handle())->default_conv_bwd_data_algo()}; | |||
| } | |||
| ConvolutionBackwardData::Algorithm* | |||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* filter */, const TensorLayout& /* diff */, | |||
| @@ -341,6 +355,13 @@ ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, | |||
| return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()}; | |||
| } | |||
| std::vector<ConvolutionBackwardFilter::Algorithm *> | |||
| ConvolutionBackwardFilterImpl:: get_all_algorithms_safe(const TensorLayout &, | |||
| const TensorLayout &, const TensorLayout &) | |||
| { | |||
| return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()}; | |||
| } | |||
| ConvolutionBackwardFilter::Algorithm* | |||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | |||
| @@ -25,6 +25,9 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
| const TensorLayout &filter, | |||
| const TensorLayout &dst) override; | |||
| std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &src, | |||
| const TensorLayout &filter, | |||
| const TensorLayout &dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||
| @@ -67,6 +70,9 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) override; | |||
| std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &filter, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -90,6 +96,9 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) override; | |||
| std::vector<Algorithm *> get_all_algorithms_safe(const TensorLayout &src, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -108,13 +108,18 @@ void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, | |||
| megdnn_assert_internal(0); | |||
| } | |||
| std::vector<Convolution3DForward::Algorithm*> | |||
| Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | |||
| } | |||
| std::vector<Convolution3DForward::Algorithm*> | |||
| Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo()}; | |||
| } | |||
| Convolution3DForward::Algorithm* | |||
| Convolution3DForwardImpl::get_algorithm_heuristic( | |||
| @@ -143,6 +148,13 @@ Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | |||
| } | |||
| std::vector<Convolution3DBackwardData::Algorithm*> | |||
| Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo()}; | |||
| } | |||
| Convolution3DBackwardData::Algorithm* | |||
| Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* filter */, const TensorLayout& /* diff */, | |||
| @@ -172,6 +184,14 @@ Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||
| ->default_conv3d_bwd_filter_algo()}; | |||
| } | |||
| std::vector<Convolution3DBackwardFilter::Algorithm*> | |||
| Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle()) | |||
| ->default_conv3d_bwd_filter_algo()}; | |||
| } | |||
| Convolution3DBackwardFilter::Algorithm* | |||
| Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | |||
| @@ -22,6 +22,9 @@ public: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||
| @@ -44,6 +47,9 @@ public: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -66,6 +72,9 @@ public: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -25,6 +25,12 @@ public: | |||
| const TensorLayout& /* dst */) override { | |||
| return std::vector<Algorithm*>(); | |||
| }; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /* im */, const TensorLayout& /* filter */, | |||
| const TensorLayout& /* offset */, const TensorLayout& /* mask */, | |||
| const TensorLayout& /* dst */) override { | |||
| return std::vector<Algorithm*>(); | |||
| }; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | |||
| @@ -67,6 +73,13 @@ public: | |||
| const TensorLayout& /* filter_grad */) override { | |||
| return std::vector<Algorithm*>(); | |||
| }; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /* im */, const TensorLayout& /* offset */, | |||
| const TensorLayout& /* mask */, const TensorLayout& /* out_grad */, | |||
| const TensorLayout& /* filter_grad */) override { | |||
| return std::vector<Algorithm*>(); | |||
| }; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /* im */, const TensorLayout& /* offset */, | |||
| @@ -112,6 +125,16 @@ public: | |||
| return std::vector<Algorithm*>(); | |||
| }; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /* im */, const TensorLayout& /* filter */, | |||
| const TensorLayout& /* offset */, const TensorLayout& /* mask */, | |||
| const TensorLayout& /* out_grad */, | |||
| const TensorLayout& /* im_grad */, | |||
| const TensorLayout& /* offset_grad */, | |||
| const TensorLayout& /* mask_grad */) override { | |||
| return std::vector<Algorithm*>(); | |||
| }; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /* im */, const TensorLayout& /* filter */, | |||
| const TensorLayout& /* offset */, const TensorLayout& /* mask */, | |||
| @@ -159,6 +159,13 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&, | |||
| return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; | |||
| } | |||
| std::vector<LocalShareForward::Algorithm*> | |||
| LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo()}; | |||
| } | |||
| LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | |||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | |||
| @@ -187,6 +194,14 @@ LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, | |||
| ->default_local_share_bwd_data_algo()}; | |||
| } | |||
| std::vector<LocalShareBackwardData::Algorithm*> | |||
| LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle()) | |||
| ->default_local_share_bwd_data_algo()}; | |||
| } | |||
| LocalShareBackwardData::Algorithm* | |||
| LocalShareBackwardDataImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* filter */, const TensorLayout& /* diff */, | |||
| @@ -216,6 +231,14 @@ LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, | |||
| ->default_local_share_bwd_filter_algo()}; | |||
| } | |||
| std::vector<LocalShareBackwardFilter::Algorithm*> | |||
| LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle()) | |||
| ->default_local_share_bwd_filter_algo()}; | |||
| } | |||
| LocalShareBackwardFilter::Algorithm* | |||
| LocalShareBackwardFilterImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | |||
| @@ -30,6 +30,10 @@ public: | |||
| const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | |||
| const TensorLayout& /*dst*/) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | |||
| const TensorLayout& /*dst*/) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | |||
| const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/, | |||
| @@ -55,6 +59,10 @@ public: | |||
| const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | |||
| const TensorLayout& /*grad*/) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | |||
| const TensorLayout& /*grad*/) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | |||
| const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | |||
| @@ -75,11 +83,14 @@ public: | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | |||
| const TensorLayout& /*grad*/) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | |||
| const TensorLayout& /*grad*/) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | |||
| const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | |||
| @@ -88,6 +88,13 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||
| return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | |||
| } | |||
| std::vector<MatrixMulForward::Algorithm*> | |||
| MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) { | |||
| return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | |||
| } | |||
| MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||
| @@ -29,6 +29,10 @@ public: | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||
| @@ -603,6 +603,10 @@ std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms( | |||
| const TensorLayout&, const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()}; | |||
| } | |||
| std::vector<Algorithm*> PoolingForwardImpl::get_all_algorithms_safe( | |||
| const TensorLayout&, const TensorLayout&) { | |||
| return {static_cast<HandleImpl*>(handle())->default_pooling_fwd_algo()}; | |||
| } | |||
| Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | |||
| @@ -626,6 +630,11 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms( | |||
| const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { | |||
| return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()}; | |||
| } | |||
| std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe( | |||
| const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | |||
| const TensorLayout& /*diff*/, const TensorLayout& /*grad*/) { | |||
| return {static_cast<HandleImpl*>(handle())->default_pooling_bwd_algo()}; | |||
| } | |||
| Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /*src*/, const TensorLayout& /*dst*/, | |||
| @@ -35,6 +35,8 @@ class PoolingForwardImpl: public PoolingForward { | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| @@ -60,6 +62,9 @@ public: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| @@ -29,6 +29,14 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||
| return megdnn::get_all_algorithms<BatchedMatrixMulForwardImpl>(args); | |||
| } | |||
| std::vector<BatchedMatrixMulForwardImpl::Algorithm*> | |||
| BatchedMatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_all_algorithms_safe<BatchedMatrixMulForwardImpl>(args); | |||
| } | |||
| BatchedMatrixMulForwardImpl::Algorithm* | |||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| @@ -35,6 +35,9 @@ private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| @@ -109,6 +109,14 @@ ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| {this, src, filter, dst}); | |||
| } | |||
| std::vector<ConvolutionForwardImpl::Algorithm*> | |||
| ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& filter, | |||
| const TensorLayout& dst) { | |||
| return megdnn::get_all_algorithms_safe<ConvolutionForwardImpl>( | |||
| {this, src, filter, dst}); | |||
| } | |||
| size_t ConvolutionForwardImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst, const PreprocessedFilter*) { | |||
| @@ -162,6 +170,14 @@ ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, | |||
| {this, filter, diff, grad}); | |||
| } | |||
| std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||
| ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| return megdnn::get_all_algorithms_safe<ConvolutionBackwardDataImpl>( | |||
| {this, filter, diff, grad}); | |||
| } | |||
| ConvolutionBackwardDataImpl::Algorithm* | |||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| @@ -243,6 +259,14 @@ ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, | |||
| {this, src, diff, grad}); | |||
| } | |||
| std::vector<ConvolutionBackwardFilterImpl::Algorithm*> | |||
| ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| return megdnn::get_all_algorithms_safe<ConvolutionBackwardFilterImpl>( | |||
| {this, src, diff, grad}); | |||
| } | |||
| ConvolutionBackwardFilterImpl::Algorithm* | |||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| @@ -74,6 +74,9 @@ private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||
| @@ -123,6 +126,9 @@ private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -172,6 +178,9 @@ private: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| @@ -27,6 +27,14 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||
| return megdnn::get_all_algorithms<MatrixMulForwardImpl>(args); | |||
| } | |||
| std::vector<MatrixMulForwardImpl::Algorithm*> | |||
| MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) { | |||
| AlgoBase::SizeArgs args{this, A, B, C}; | |||
| return megdnn::get_all_algorithms_safe<MatrixMulForwardImpl>(args); | |||
| } | |||
| MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||
| @@ -36,6 +36,10 @@ private: | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||
| @@ -25,12 +25,16 @@ size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
| const char* PoolingForwardImpl::get_algorithm_set_name() const { | |||
| return "ROCM_POOLING_FORWARD"; | |||
| } | |||
| std::vector<PoolingForwardImpl::Algorithm*> | |||
| PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, | |||
| const TensorLayout& dst) { | |||
| return megdnn::get_all_algorithms<PoolingForwardImpl>({this, src, dst}); | |||
| } | |||
| std::vector<PoolingForwardImpl::Algorithm*> | |||
| PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src, | |||
| const TensorLayout& dst) { | |||
| return megdnn::get_all_algorithms_safe<PoolingForwardImpl>({this, src, dst}); | |||
| } | |||
| PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| @@ -82,6 +86,13 @@ std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms( | |||
| {this, src, dst, diff, grad}); | |||
| } | |||
| std::vector<Algorithm*> PoolingBackwardImpl::get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad) { | |||
| return megdnn::get_all_algorithms_safe<PoolingBackwardImpl>( | |||
| {this, src, dst, diff, grad}); | |||
| } | |||
| Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| @@ -46,6 +46,8 @@ class PoolingForwardImpl final: public PoolingForward { | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||
| @@ -93,6 +95,9 @@ class PoolingBackwardImpl final: public PoolingBackward { | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| @@ -74,11 +74,14 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
| return fallback_worksapce; | |||
| } | |||
| } | |||
| std::vector<Algorithm*> PoolingImpl::get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& dst) { | |||
| return megdnn::get_all_algorithms<PoolingImpl>({this, src, dst}); | |||
| } | |||
| std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst) { | |||
| return megdnn::get_all_algorithms_safe<PoolingImpl>({this, src, dst}); | |||
| } | |||
| Algorithm* PoolingImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| @@ -63,6 +63,8 @@ public: | |||
| protected: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& src, const TensorLayout& dst) override; | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||
| @@ -164,7 +164,7 @@ public: | |||
| } | |||
| std::vector<Algorithm::Info::Desc> ret; | |||
| megdnn_assert(layouts.size() == OprTrait<Opr>::arity); | |||
| auto vec = AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info( | |||
| auto vec = AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe( | |||
| opr, layouts); | |||
| for (auto algo_info : vec) { | |||
| if (!(algo_info.attribute & | |||
| @@ -377,7 +377,7 @@ float algo_benchmark(Benchmarker<Opr, T>& benchmark, TensorLayoutArray layouts, | |||
| auto opr = benchmark.opr(); | |||
| opr->param() = benchmark.param(); | |||
| proxy.deduce_layout(opr, layouts); | |||
| auto algos = OprAlgoProxy<Opr>::get_all_algorithms_info(opr, layouts); | |||
| auto algos = OprAlgoProxy<Opr>::get_all_algorithms_info_safe(opr, layouts); | |||
| float min_used = std::numeric_limits<float>::max(); | |||
| bool execed = false; | |||
| for (auto i : algos) { | |||
| @@ -514,7 +514,7 @@ struct ExecutionPolicyAlgoName { | |||
| * \brief a callable to check that given algorithm is used for heuristic | |||
| * \param require_algo if its value is true, then requires | |||
| * get_algorithm_heuristic() to return the expected algo; otherwise the | |||
| * expected algo must exist in get_all_algorithms() and it would be set to | |||
| * expected algo must exist in get_all_algorithms_safe() and it would be set to | |||
| * be used | |||
| */ | |||
| template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>> | |||
| @@ -536,7 +536,7 @@ public: | |||
| opr->param() = | |||
| Algorithm::deserialize_read_pod<typename Opr::Param>(param); | |||
| for (auto algo_info : | |||
| AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info( | |||
| AlgoProxy<Opr, OprTrait<Opr>::arity>::get_all_algorithms_info_safe( | |||
| opr.get(), layouts)) { | |||
| if (std::regex_match( | |||
| algo_info.desc.name, | |||
| @@ -695,7 +695,7 @@ Checker<Convolution> checker(handle); | |||
| float scale = 1.0f / sqrt(fshp[channel_start] * FH * FW); | |||
| UniformFloatRNG rng(scale, 2 * scale); | |||
| checker.set_rng(0, &rng).set_rng(1, &rng); | |||
| for (auto algo : opr->get_all_algorithms_info(ily, fly, oly)) { | |||
| for (auto algo : opr->get_all_algorithms_info_safe(ily, fly, oly)) { | |||
| used_algos.insert(algo.desc); | |||
| opr->execution_policy().algo = algo.desc; | |||
| @@ -720,7 +720,7 @@ Checker<Convolution> checker(handle); | |||
| opr->param() = param; | |||
| std::string param_str; | |||
| Algorithm::serialize_write_pod(opr->param(), param_str); | |||
| for (auto algo : opr->get_all_algorithms_info(fly, oly, ily)) { | |||
| for (auto algo : opr->get_all_algorithms_info_safe(fly, oly, ily)) { | |||
| used_algos_bwd_data.insert(algo.desc); | |||
| opr->execution_policy().algo = algo.desc; | |||
| construct_sub_execution_policy_heuristic< | |||
| @@ -747,7 +747,7 @@ Checker<Convolution> checker(handle); | |||
| opr->param() = param; | |||
| std::string param_str; | |||
| Algorithm::serialize_write_pod(opr->param(), param_str); | |||
| for (auto algo : opr->get_all_algorithms_info(ily, oly, fly)) { | |||
| for (auto algo : opr->get_all_algorithms_info_safe(ily, oly, fly)) { | |||
| used_algos_bwd_flt.insert(algo.desc); | |||
| opr->execution_policy().algo = algo.desc; | |||
| construct_sub_execution_policy_heuristic< | |||
| @@ -25,9 +25,9 @@ struct AlgoProxy; | |||
| template <typename Opr> \ | |||
| struct AlgoProxy<Opr, arity> { \ | |||
| static std::vector<typename Opr::AlgorithmInfo> \ | |||
| get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \ | |||
| get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \ | |||
| megdnn_assert(layouts.size() == arity); \ | |||
| return opr->get_all_algorithms_info(LAYOUTS); \ | |||
| return opr->get_all_algorithms_info_safe(LAYOUTS); \ | |||
| } \ | |||
| static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | |||
| Opr* opr, const TensorLayoutArray& layouts) { \ | |||
| @@ -80,9 +80,9 @@ DEF_ALGO_PROXY(8); | |||
| template <> \ | |||
| struct AlgoProxy<Opr, arity> { \ | |||
| static std::vector<typename Opr::AlgorithmInfo> \ | |||
| get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \ | |||
| get_all_algorithms_info_safe(Opr* opr, const TensorLayoutArray& layouts) { \ | |||
| megdnn_assert(layouts.size() == arity); \ | |||
| return opr->get_all_algorithms_info(LAYOUTS); \ | |||
| return opr->get_all_algorithms_info_safe(LAYOUTS); \ | |||
| } \ | |||
| static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \ | |||
| Opr* opr, const TensorLayoutArray& layouts) { \ | |||
| @@ -288,7 +288,7 @@ struct OprProxyProfilingBase | |||
| Algorithm::deserialize_read_pod<typename Opr::Param>(param); | |||
| std::vector<Algorithm::SearchItem> ret; | |||
| for (auto algo_info : AlgoProxy<Opr, arity>::get_all_algorithms_info( | |||
| for (auto algo_info : AlgoProxy<Opr, arity>::get_all_algorithms_info_safe( | |||
| opr.get(), layouts)) { | |||
| Algorithm* algo = opr->get_algorithm_from_desc(algo_info.desc); | |||
| std::vector<Algorithm::SearchItem>&& sub_items = | |||
| @@ -367,7 +367,7 @@ struct OprProxyProfilingBase | |||
| megdnn_log("Find best algo %s in cache", algo->name()); | |||
| return; | |||
| } | |||
| for (auto algo : AlgoProxy<Opr, arity>::get_all_algorithms_info( | |||
| for (auto algo : AlgoProxy<Opr, arity>::get_all_algorithms_info_safe( | |||
| opr.get(), layouts)) { | |||
| //! construct execution_policy | |||
| opr->execution_policy().algo = algo.desc; | |||
| @@ -492,7 +492,7 @@ struct OprWeightPreprocessProxyImpl : public OprProxyProfilingBase<Opr> { | |||
| if (Base::m_profiling && !Base::target_execution_policy.algo.valid()) { | |||
| size_t min_time = std::numeric_limits<size_t>::max(); | |||
| for (auto algo : | |||
| AlgoProxy<Opr, arity>::get_all_algorithms_info(opr, layouts)) { | |||
| AlgoProxy<Opr, arity>::get_all_algorithms_info_safe(opr, layouts)) { | |||
| opr->execution_policy().algo = algo.desc; | |||
| auto preprocess_tensors = | |||
| @@ -84,7 +84,7 @@ void test_multibatchsize( | |||
| auto opr_reference = handle_cuda->create_operator<MatrixMulForward>(); | |||
| { | |||
| opr_reference->execution_policy().algo.reset(); | |||
| for (auto i : opr_reference->get_all_algorithms_info( | |||
| for (auto i : opr_reference->get_all_algorithms_info_safe( | |||
| A_tensor.layout(), B_tensor.layout(), | |||
| C_tensor.layout())) { | |||
| if (std::regex_match( | |||
| @@ -113,7 +113,7 @@ void test_multibatchsize( | |||
| {{}, {}, C_tensor_prime.tensornd_host()}); | |||
| { | |||
| opr_reference->execution_policy().algo.reset(); | |||
| for (auto i : opr_reference->get_all_algorithms_info( | |||
| for (auto i : opr_reference->get_all_algorithms_info_safe( | |||
| A_tensor_prime.layout(), B_tensor.layout(), | |||
| C_tensor_batch.layout())) { | |||
| if (std::regex_match( | |||
| @@ -1938,7 +1938,7 @@ typename megdnn::ExecutionPolicy try_find_any_weight_preprocess_algo( | |||
| return {}; | |||
| } | |||
| } | |||
| for (auto&& algo : dnn_op->get_all_algorithms_info( | |||
| for (auto&& algo : dnn_op->get_all_algorithms_info_safe( | |||
| std::forward<Args>(args)...)) { | |||
| dnn_op->execution_policy().algo = algo.desc; | |||
| auto layouts = dnn_op->deduce_preprocessed_filter_layout( | |||
| @@ -1972,7 +1972,7 @@ typename megdnn::ExecutionPolicy try_find_any_bias_preprocess_algo( | |||
| return {}; | |||
| } | |||
| } | |||
| for (auto&& algo : dnn_op->get_all_algorithms_info( | |||
| for (auto&& algo : dnn_op->get_all_algorithms_info_safe( | |||
| std::forward<Args>(args)...)) { | |||
| dnn_op->execution_policy().algo = algo.desc; | |||
| auto layouts = dnn_op->deduce_preprocessed_filter_layout( | |||
| @@ -805,7 +805,7 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo> | |||
| AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { | |||
| MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) | |||
| auto heu = choose_by_heuristic(m_execution_policy.strategy); | |||
| auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...), | |||
| auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info_safe(args...), | |||
| m_fastrun_layouts); | |||
| bool found = false; | |||
| for (size_t i = 0; i < ret.size(); ++i) { | |||
| @@ -2473,6 +2473,11 @@ public: | |||
| std::vector<AlgorithmInfo>(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2)); | |||
| MOCK_METHOD3(get_all_algorithms_info_safe, | |||
| std::vector<AlgorithmInfo>(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2)); | |||
| MOCK_METHOD6(get_algorithm_info_heuristic, | |||
| AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| @@ -2484,6 +2489,11 @@ public: | |||
| std::vector<Algorithm*>(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2)); | |||
| MOCK_METHOD3(get_all_algorithms_safe, | |||
| std::vector<Algorithm*>(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2)); | |||
| MOCK_METHOD6(get_algorithm_heuristic, | |||
| Algorithm*(const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||