GitOrigin-RevId: 88b1ce94a5
tags/v1.3.1
| @@ -165,7 +165,15 @@ public: | |||||
| virtual std::string param() const { return {}; } | virtual std::string param() const { return {}; } | ||||
| virtual uint32_t type() const = 0; | virtual uint32_t type() const = 0; | ||||
| bool contain_attribute(const Attribute& attr) const; | |||||
| //! if algo contain all of the attribute in attr | |||||
| bool contain_attribute_all(const Attribute& attr) const; | |||||
| //! if algo contain any attribute in attr | |||||
| bool contain_attribute_any(const Attribute& attr) const; | |||||
| void check_attribute( | |||||
| const Attribute& positive_attr = Attribute::DEFAULT, | |||||
| const Attribute& negative_attr = Attribute::DEFAULT) const; | |||||
| static std::string attribute_str(const Attribute& attr); | static std::string attribute_str(const Attribute& attr); | ||||
| @@ -342,9 +350,10 @@ public: | |||||
| const TensorLayout& p2, | const TensorLayout& p2, | ||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| const AlgoAttribute& attr = AlgoAttribute::DEFAULT) { | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||||
| return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, | return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, | ||||
| attr) | |||||
| positive_attr, negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -367,7 +376,8 @@ protected: | |||||
| const TensorLayout& p2, | const TensorLayout& p2, | ||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0; | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||||
| }; | }; | ||||
| //! specializae for nargs == 4 | //! specializae for nargs == 4 | ||||
| @@ -402,9 +412,10 @@ public: | |||||
| const TensorLayout& p2, const TensorLayout& p3, | const TensorLayout& p2, const TensorLayout& p3, | ||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| const AlgoAttribute& attr = AlgoAttribute::DEFAULT) { | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||||
| return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, | return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, | ||||
| attr) | |||||
| positive_attr, negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -427,7 +438,8 @@ protected: | |||||
| const TensorLayout& p2, const TensorLayout& p3, | const TensorLayout& p2, const TensorLayout& p3, | ||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0; | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||||
| }; | }; | ||||
| //! specializae for nargs == 5 | //! specializae for nargs == 5 | ||||
| @@ -464,9 +476,11 @@ public: | |||||
| const TensorLayout& p4, | const TensorLayout& p4, | ||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| const AlgoAttribute& attr = AlgoAttribute::DEFAULT) { | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, | return get_algorithm_heuristic(p0, p1, p2, p3, p4, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -491,7 +505,8 @@ protected: | |||||
| const TensorLayout& p4, | const TensorLayout& p4, | ||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0; | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||||
| }; | }; | ||||
| //! specializae for nargs == 8 | //! specializae for nargs == 8 | ||||
| @@ -528,9 +543,11 @@ public: | |||||
| const TensorLayout& p6, const TensorLayout& p7, | const TensorLayout& p6, const TensorLayout& p7, | ||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| const AlgoAttribute& attr = AlgoAttribute::DEFAULT) { | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, | return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -557,7 +574,8 @@ protected: | |||||
| const TensorLayout& p6, const TensorLayout& p7, | const TensorLayout& p6, const TensorLayout& p7, | ||||
| size_t workspace_limit_in_bytes = | size_t workspace_limit_in_bytes = | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| const AlgoAttribute& attr = AlgoAttribute::DEFAULT) = 0; | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||||
| }; | }; | ||||
| } // namespace detail | } // namespace detail | ||||
| @@ -27,7 +27,7 @@ inline const char* attr_str(const AlgoAttribute& attr) { | |||||
| return #attr; | return #attr; | ||||
| switch (attr) { FOREACH_ALGO_ATTRIBUTE(cb) } | switch (attr) { FOREACH_ALGO_ATTRIBUTE(cb) } | ||||
| #undef cb | #undef cb | ||||
| return "unknown arch"; | |||||
| return "UNKNOWN"; | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -43,11 +43,30 @@ std::string Algorithm::attribute_str(const Attribute& attr) { | |||||
| ret.append(attr_str(sub_attr)); | ret.append(attr_str(sub_attr)); | ||||
| attr_val = attr_val & (attr_val - 1); | attr_val = attr_val & (attr_val - 1); | ||||
| } | } | ||||
| if (ret.empty()) { | |||||
| ret = "DEFAULT"; | |||||
| } | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| bool Algorithm::contain_attribute(const Attribute& attr) const { | |||||
| bool Algorithm::contain_attribute_all(const Attribute& attr) const { | |||||
| return attr == static_cast<Attribute>(attribute() & attr); | return attr == static_cast<Attribute>(attribute() & attr); | ||||
| } | } | ||||
| bool Algorithm::contain_attribute_any(const Attribute& attr) const { | |||||
| return static_cast<bool>(attribute() & attr); | |||||
| } | |||||
| void Algorithm::check_attribute(const Attribute& positive_attr, | |||||
| const Attribute& negative_attr) const { | |||||
| megdnn_assert(contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr), | |||||
| "require algorithm with attribute(%s) and without " | |||||
| "attribute(%s), but get" | |||||
| "algorithm(%s) with attribute(%s) ", | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| Algorithm::attribute_str(negative_attr).c_str(), name(), | |||||
| Algorithm::attribute_str(attribute()).c_str()); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -32,7 +32,7 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||||
| } else { | } else { | ||||
| ret = opr->get_algorithm_info_heuristic( | ret = opr->get_algorithm_info_heuristic( | ||||
| std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | std::forward<Args>(args)..., std::numeric_limits<size_t>::max(), | ||||
| AlgoAttribute::DEFAULT).desc; | |||||
| AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT).desc; | |||||
| } | } | ||||
| return static_cast<typename Opr::AlgoBase*>( | return static_cast<typename Opr::AlgoBase*>( | ||||
| opr->get_algorithm_from_desc(ret)); | opr->get_algorithm_from_desc(ret)); | ||||
| @@ -51,6 +51,7 @@ typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { | |||||
| return static_cast<typename Opr::AlgoBase*>( | return static_cast<typename Opr::AlgoBase*>( | ||||
| opr->get_algorithm_heuristic(std::forward<Args>(args)..., | opr->get_algorithm_heuristic(std::forward<Args>(args)..., | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| AlgoAttribute::DEFAULT, | |||||
| AlgoAttribute::DEFAULT)); | AlgoAttribute::DEFAULT)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -74,34 +75,37 @@ std::vector<typename Opr::Algorithm*> get_all_algorithms( | |||||
| } | } | ||||
| /*! | /*! | ||||
| * \brief a helper function to get an algorithm with attribute. If require a | |||||
| * algorithm with specified attribute, and the given algorithm has that | |||||
| * \brief a helper function to get an algorithm match attribute. If require a | |||||
| * algorithm with specified attribute, and the given algorithm match that | |||||
| * attribute, return the given algorithm. Otherwise return nullptr | * attribute, return the given algorithm. Otherwise return nullptr | ||||
| */ | */ | ||||
| template <typename Opr> | template <typename Opr> | ||||
| typename Opr::Algorithm* get_algo_with_attribute(typename Opr::AlgoBase* algo, | |||||
| const AlgoAttribute& attr) { | |||||
| if (algo->contain_attribute(attr)) { | |||||
| typename Opr::Algorithm* get_algo_match_attribute( | |||||
| typename Opr::AlgoBase* algo, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| if (algo->contain_attribute_all(positive_attr) && | |||||
| !algo->contain_attribute_any(negative_attr)) { | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| template <typename Opr> | template <typename Opr> | ||||
| typename Opr::Algorithm* get_algo_with_attribute( | |||||
| typename Opr::Algorithm* get_algo_match_attribute( | |||||
| const std::vector<typename Opr::AlgoBase*>& algos, | const std::vector<typename Opr::AlgoBase*>& algos, | ||||
| const typename Opr::AlgoBase::SizeArgs& args, | const typename Opr::AlgoBase::SizeArgs& args, | ||||
| size_t workspace_limit_in_bytes, const char* name, | size_t workspace_limit_in_bytes, const char* name, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) { | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||||
| size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max(); | size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max(); | ||||
| bool available_but_limited_by_workspace = false; | bool available_but_limited_by_workspace = false; | ||||
| bool available_but_without_attribute = false; | |||||
| bool available_but_attribute_mismatch = false; | |||||
| for (auto i : algos) { | for (auto i : algos) { | ||||
| if (i->is_available_attribute(args, attr, | |||||
| if (i->is_available_attribute(args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes)) { | workspace_limit_in_bytes)) { | ||||
| return i; | return i; | ||||
| } | } | ||||
| if (i->is_available_attribute(args)) { | |||||
| if (i->is_available_attribute(args, positive_attr, negative_attr)) { | |||||
| if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) { | if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) { | ||||
| available_but_limited_by_workspace = true; | available_but_limited_by_workspace = true; | ||||
| min_workspace_limit_in_bytes = | min_workspace_limit_in_bytes = | ||||
| @@ -110,53 +114,27 @@ typename Opr::Algorithm* get_algo_with_attribute( | |||||
| } | } | ||||
| } | } | ||||
| if (i->is_available(args)) { | if (i->is_available(args)) { | ||||
| if (!i->contain_attribute(attr)) | |||||
| available_but_without_attribute = true; | |||||
| if (!(i->contain_attribute_all(positive_attr) && | |||||
| !i->contain_attribute_any(negative_attr))) | |||||
| available_but_attribute_mismatch = true; | |||||
| } | } | ||||
| } | } | ||||
| MEGDNN_MARK_USED_VAR(name); | MEGDNN_MARK_USED_VAR(name); | ||||
| if (available_but_limited_by_workspace) { | if (available_but_limited_by_workspace) { | ||||
| megdnn_throw( | |||||
| ssprintf("no %s algorithm without attribute(%s) with " | |||||
| "attribute(%s) : %s workspace limit %zu is " | |||||
| "less than mini workspace limit %zu", | |||||
| name, Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes, | |||||
| min_workspace_limit_in_bytes)); | |||||
| } else if (available_but_attribute_mismatch) { | |||||
| megdnn_throw(ssprintf( | megdnn_throw(ssprintf( | ||||
| "no %s algorithm with attribute:%s : %s workspace limit %zu is " | |||||
| "less than mini workspace limit %zu", | |||||
| name, Algorithm::attribute_str(attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes, | |||||
| min_workspace_limit_in_bytes)); | |||||
| } else if (available_but_without_attribute) { | |||||
| megdnn_throw(ssprintf("no %s algorithm with attribute:%s", name, | |||||
| Algorithm::attribute_str(attr).c_str())); | |||||
| } else { | |||||
| megdnn_throw(ssprintf("no usable %s algorithm", name)); | |||||
| } | |||||
| } | |||||
| template <typename Opr> | |||||
| typename Opr::Algorithm* get_usable_algo( | |||||
| const std::vector<typename Opr::AlgoBase*>& algos, | |||||
| const typename Opr::AlgoBase::SizeArgs& args, | |||||
| size_t workspace_limit_in_bytes, const char* name) { | |||||
| size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max(); | |||||
| bool available_but_limited_by_workspace = false; | |||||
| for (auto i : algos) { | |||||
| if (i->is_available_wk(args, workspace_limit_in_bytes)) { | |||||
| return i; | |||||
| } | |||||
| if (i->is_available(args)) { | |||||
| available_but_limited_by_workspace = true; | |||||
| min_workspace_limit_in_bytes = | |||||
| std::min(min_workspace_limit_in_bytes, | |||||
| i->get_workspace_in_bytes(args)); | |||||
| } | |||||
| } | |||||
| MEGDNN_MARK_USED_VAR(name); | |||||
| if (available_but_limited_by_workspace) { | |||||
| megdnn_throw(ssprintf( | |||||
| "no usable %s algorithm: %s workspace limit %zu is " | |||||
| "less than mini workspace limit %zu", | |||||
| name, args.to_string().c_str(), workspace_limit_in_bytes, | |||||
| min_workspace_limit_in_bytes)); | |||||
| "no %s algorithm without attribute(%s) with attribute(%s)", name, | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str())); | |||||
| } else { | } else { | ||||
| megdnn_throw(ssprintf("no usable %s algorithm", name)); | megdnn_throw(ssprintf("no usable %s algorithm", name)); | ||||
| } | } | ||||
| @@ -67,9 +67,12 @@ public: | |||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| @@ -22,21 +22,24 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, src, filter, bias, z, dst); | AlgoBase::SizeArgs args(this, src, filter, bias, z, dst); | ||||
| if (sm_algo_pack.int8_nchw4_gemm_dotprod.is_available_attribute( | if (sm_algo_pack.int8_nchw4_gemm_dotprod.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.int8_nchw4_gemm_dotprod; | return &sm_algo_pack.int8_nchw4_gemm_dotprod; | ||||
| } | } | ||||
| if (sm_algo_pack.int8_nchw4_implicit_gemm_dotprod.is_available_attribute( | if (sm_algo_pack.int8_nchw4_implicit_gemm_dotprod.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.int8_nchw4_implicit_gemm_dotprod; | return &sm_algo_pack.int8_nchw4_implicit_gemm_dotprod; | ||||
| } | } | ||||
| megdnn_throw(ssprintf( | |||||
| "no batch conv bias algorithm with attribute%s args(%s) and " | |||||
| "workspace limit (%zu bytes)", | |||||
| Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(), | |||||
| workspace_limit_in_bytes)); | |||||
| megdnn_throw( | |||||
| ssprintf("no batch conv bias algorithm without attribute(%s) with " | |||||
| "attribute(%s) args(%s) and " | |||||
| "workspace limit (%zu bytes)", | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
| } | } | ||||
| std::vector<BatchConvBiasForwardImpl::Algorithm*> | std::vector<BatchConvBiasForwardImpl::Algorithm*> | ||||
| @@ -42,13 +42,12 @@ protected: | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& bias, | |||||
| const TensorLayout& z, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -70,9 +70,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -55,26 +55,37 @@ std::vector<Algorithm*> BatchedMatrixMulForwardImpl::get_all_algorithms( | |||||
| Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | ||||
| AlgoBase::SizeArgs args(this, A, B, C); | AlgoBase::SizeArgs args(this, A, B, C); | ||||
| if (sm_algo_pack.cublas.is_available_attribute(args, attr)) { | |||||
| if (sm_algo_pack.cublas.is_available_attribute(args, positive_attr, | |||||
| negative_attr)) { | |||||
| return &sm_algo_pack.cublas; | return &sm_algo_pack.cublas; | ||||
| } | } | ||||
| #if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
| else if (sm_algo_pack.cublasLt.is_available_attribute(args, attr)) { | |||||
| else if (sm_algo_pack.cublasLt.is_available_attribute(args, positive_attr, | |||||
| negative_attr)) { | |||||
| return &sm_algo_pack.cublasLt; | return &sm_algo_pack.cublasLt; | ||||
| } | } | ||||
| #endif | #endif | ||||
| else if (sm_algo_pack.int8x8x32.is_available_attribute(args, attr)) { | |||||
| else if (sm_algo_pack.int8x8x32.is_available_attribute(args, positive_attr, | |||||
| negative_attr)) { | |||||
| return &sm_algo_pack.int8x8x32; | return &sm_algo_pack.int8x8x32; | ||||
| } else { | } else { | ||||
| if (sm_algo_pack.brute_force.is_available_attribute(args, attr)) { | |||||
| if (sm_algo_pack.brute_force.is_available_attribute(args, positive_attr, | |||||
| negative_attr)) { | |||||
| return &sm_algo_pack.brute_force; | return &sm_algo_pack.brute_force; | ||||
| } | } | ||||
| } | } | ||||
| megdnn_throw("No usable algo for batched_matrix_mul"); | |||||
| megdnn_throw(ssprintf( | |||||
| "no batched_matrix_mul algorithm without attribute(%s) with " | |||||
| "attribute(%s) args(%s) and " | |||||
| "workspace limit (%zu bytes)", | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
| return nullptr; | return nullptr; | ||||
| }; | }; | ||||
| @@ -45,11 +45,10 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | ||||
| const TensorLayout& B, | const TensorLayout& B, | ||||
| const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -129,9 +129,12 @@ public: | |||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| @@ -426,7 +429,7 @@ public: | |||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| auto ret = static_cast<AlgoAttribute>(0); | auto ret = static_cast<AlgoAttribute>(0); | ||||
| if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | |||||
| ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -51,7 +51,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | ||||
| auto dst_layout = *args.dst_layout; | auto dst_layout = *args.dst_layout; | ||||
| @@ -74,7 +75,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| }; | }; | ||||
| auto get_cudnn_algo = | auto get_cudnn_algo = | ||||
| [this, &conv_args, &args, workspace_limit_in_bytes, attr]( | |||||
| [this, &conv_args, &args, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr]( | |||||
| const thin_function<AlgoBase*(cudnnConvolutionFwdAlgo_t)>& | const thin_function<AlgoBase*(cudnnConvolutionFwdAlgo_t)>& | ||||
| cb) -> AlgoBase* { | cb) -> AlgoBase* { | ||||
| auto cudnn_handle = cuda::cudnn_handle(this->handle()); | auto cudnn_handle = cuda::cudnn_handle(this->handle()); | ||||
| @@ -93,7 +95,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| for (int i = 0; i < ret_count; ++i) { | for (int i = 0; i < ret_count; ++i) { | ||||
| auto conv_bias_algo = cb(algo_perf[i].algo); | auto conv_bias_algo = cb(algo_perf[i].algo); | ||||
| if (conv_bias_algo->is_available_attribute( | if (conv_bias_algo->is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) | |||||
| args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes)) | |||||
| return conv_bias_algo; | return conv_bias_algo; | ||||
| } | } | ||||
| #else | #else | ||||
| @@ -105,18 +108,20 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| workspace_limit_in_bytes, &algo)); | workspace_limit_in_bytes, &algo)); | ||||
| auto conv_bias_algo = cb(algo); | auto conv_bias_algo = cb(algo); | ||||
| if (conv_bias_algo->is_available_attribute(args, attr, | |||||
| if (conv_bias_algo->is_available_attribute(args, positive_attr, | |||||
| negative_attr, | |||||
| workspace_limit_in_bytes)) | workspace_limit_in_bytes)) | ||||
| return conv_bias_algo; | return conv_bias_algo; | ||||
| #endif | #endif | ||||
| return nullptr; | return nullptr; | ||||
| }; | }; | ||||
| auto get_1x1_algo = [workspace_limit_in_bytes, | |||||
| attr](const AlgoBase::SizeArgs& size_arg) | |||||
| auto get_1x1_algo = [workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr](const AlgoBase::SizeArgs& size_arg) | |||||
| -> ConvBiasForwardImpl::AlgoBase* { | -> ConvBiasForwardImpl::AlgoBase* { | ||||
| if (sm_algo_pack.batched_matmul.is_available_attribute( | if (sm_algo_pack.batched_matmul.is_available_attribute( | ||||
| size_arg, attr, workspace_limit_in_bytes)) { | |||||
| size_arg, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.batched_matmul; | return &sm_algo_pack.batched_matmul; | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| @@ -145,10 +150,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| if (is_chanwise) { | if (is_chanwise) { | ||||
| if (prefer_dnn_chanwise) { | if (prefer_dnn_chanwise) { | ||||
| if (sm_algo_pack.chanwise.is_available_attribute( | if (sm_algo_pack.chanwise.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) | |||||
| args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes)) | |||||
| return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
| if (sm_algo_pack.chanwise8x8x32.is_available_attribute( | if (sm_algo_pack.chanwise8x8x32.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) | |||||
| args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes)) | |||||
| return &sm_algo_pack.chanwise8x8x32; | return &sm_algo_pack.chanwise8x8x32; | ||||
| } else { | } else { | ||||
| conv_args.dst_layout = &dst_layout; | conv_args.dst_layout = &dst_layout; | ||||
| @@ -163,7 +170,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| //! Prefer CUDNN CONVBIAS. | //! Prefer CUDNN CONVBIAS. | ||||
| bool cudnn_conv_bias_act_supported = false; | bool cudnn_conv_bias_act_supported = false; | ||||
| for (auto&& algo : sm_algo_pack.cudnn_conv_bias_activations) { | for (auto&& algo : sm_algo_pack.cudnn_conv_bias_activations) { | ||||
| if (algo.is_available_attribute(args, attr, workspace_limit_in_bytes)) { | |||||
| if (algo.is_available_attribute(args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes)) { | |||||
| cudnn_conv_bias_act_supported = true; | cudnn_conv_bias_act_supported = true; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -201,30 +209,18 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| } | } | ||||
| if (sm_algo_pack.fallback_nchw_qs8.is_available_attribute( | if (sm_algo_pack.fallback_nchw_qs8.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.fallback_nchw_qs8; | return &sm_algo_pack.fallback_nchw_qs8; | ||||
| } | } | ||||
| if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { | if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<ConvBiasForwardImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, | |||||
| workspace_limit_in_bytes, "cuda convbias fwd", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvBiasForwardImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, | |||||
| workspace_limit_in_bytes, "cuda convbias fwd"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda convbias fwd", positive_attr, negative_attr); | |||||
| } else { | } else { | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<ConvBiasForwardImpl>( | |||||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
| "cuda convbias fwd", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvBiasForwardImpl>( | |||||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
| "cuda convbias fwd"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvBiasForwardImpl>( | |||||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
| "cuda convbias fwd", positive_attr, negative_attr); | |||||
| } | } | ||||
| } | } | ||||
| @@ -76,13 +76,12 @@ public: | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& bias, | |||||
| const TensorLayout& z, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -84,9 +84,12 @@ public: | |||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| @@ -229,7 +232,7 @@ public: | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| auto ret = static_cast<AlgoAttribute>(0); | auto ret = static_cast<AlgoAttribute>(0); | ||||
| if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | |||||
| ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -80,9 +80,12 @@ public: | |||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| @@ -214,7 +217,7 @@ public: | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| auto ret = static_cast<AlgoAttribute>(0); | auto ret = static_cast<AlgoAttribute>(0); | ||||
| if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | |||||
| ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -65,9 +65,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| @@ -33,14 +33,15 @@ using namespace convolution; | |||||
| /* ============== ConvolutionForwardImpl ============== */ | /* ============== ConvolutionForwardImpl ============== */ | ||||
| ConvolutionForwardImpl::Algorithm* | ConvolutionForwardImpl::Algorithm* | ||||
| ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) { | |||||
| ConvolutionForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args{this, src, filter, dst}; | AlgoBase::SizeArgs args{this, src, filter, dst}; | ||||
| MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); | ||||
| MEGDNN_MARK_USED_VAR(attr); | |||||
| MEGDNN_MARK_USED_VAR(positive_attr); | |||||
| MEGDNN_MARK_USED_VAR(negative_attr); | |||||
| return &sm_algo_pack.algo_default; | return &sm_algo_pack.algo_default; | ||||
| } | } | ||||
| @@ -101,46 +102,45 @@ ConvolutionBackwardDataImpl::Algorithm* | |||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = check_layout_fwd(grad, filter, diff); | auto fm = check_layout_fwd(grad, filter, diff); | ||||
| return get_algorithm_heuristic(filter, fm, diff, grad, | return get_algorithm_heuristic(filter, fm, diff, grad, | ||||
| workspace_limit_in_bytes, attr); | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr); | |||||
| } | } | ||||
| ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | ||||
| const CanonizedFilterMeta& filter_meta, const TensorLayout& diff, | const CanonizedFilterMeta& filter_meta, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad); | AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad); | ||||
| if (args.filter_meta.group > 1 && | if (args.filter_meta.group > 1 && | ||||
| sm_algo_pack.chanwise.is_available_attribute( | sm_algo_pack.chanwise.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| // prefer special chanwise impl | // prefer special chanwise impl | ||||
| return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
| } | } | ||||
| if (args.filter_layout->dtype.enumv() == | if (args.filter_layout->dtype.enumv() == | ||||
| DTypeTrait<dtype::QuantizedS8>::enumv) { | DTypeTrait<dtype::QuantizedS8>::enumv) { | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.int8_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_data", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.int8_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_data"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.int8_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_data", positive_attr, negative_attr); | |||||
| } | } | ||||
| auto get_cudnn_algo = [this, &args, workspace_limit_in_bytes, | |||||
| attr]() -> ConvolutionBackwardDataImpl::AlgoBase* { | |||||
| auto get_cudnn_algo = | |||||
| [this, &args, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr]() -> ConvolutionBackwardDataImpl::AlgoBase* { | |||||
| auto cudnn_handle = cuda::cudnn_handle(this->handle()); | auto cudnn_handle = cuda::cudnn_handle(this->handle()); | ||||
| CUDNNBwdDataDescs desc; | CUDNNBwdDataDescs desc; | ||||
| args.init_desc(desc); | args.init_desc(desc); | ||||
| #if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
| MEGDNN_MARK_USED_VAR(negative_attr); | |||||
| int max_count = 0; | int max_count = 0; | ||||
| cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | ||||
| cudnn_handle, &max_count)); | cudnn_handle, &max_count)); | ||||
| @@ -153,7 +153,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||||
| for (int i = 0; i < ret_count; ++i) { | for (int i = 0; i < ret_count; ++i) { | ||||
| if (algo_perf[i].memory > workspace_limit_in_bytes) | if (algo_perf[i].memory > workspace_limit_in_bytes) | ||||
| continue; | continue; | ||||
| if (attr & AlgoAttribute::REPRODUCIBLE) { | |||||
| if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { | if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { | ||||
| return reinterpret_cast<AlgoBase*>( | return reinterpret_cast<AlgoBase*>( | ||||
| sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); | sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); | ||||
| @@ -174,8 +174,8 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||||
| auto&& cast_algo = | auto&& cast_algo = | ||||
| reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo)); | reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo)); | ||||
| return reinterpret_cast<AlgoBase*>( | return reinterpret_cast<AlgoBase*>( | ||||
| megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>( | |||||
| cast_algo, attr)); | |||||
| megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>( | |||||
| cast_algo, positive_attr, negative_attr)); | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| @@ -197,25 +197,13 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||||
| if (args.filter_layout->dtype.enumv() != | if (args.filter_layout->dtype.enumv() != | ||||
| DTypeTrait<dtype::BFloat16>::enumv) { | DTypeTrait<dtype::BFloat16>::enumv) { | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, | |||||
| workspace_limit_in_bytes, "cuda conv bwd_data", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, | |||||
| workspace_limit_in_bytes, "cuda conv bwd_data"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_data", positive_attr, negative_attr); | |||||
| } else { | } else { | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_data", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_data"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_data", positive_attr, negative_attr); | |||||
| } | } | ||||
| } | } | ||||
| @@ -255,29 +243,33 @@ ConvolutionBackwardFilterImpl::Algorithm* | |||||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = check_layout_fwd(src, grad, diff); | auto fm = check_layout_fwd(src, grad, diff); | ||||
| return get_algorithm_heuristic(src, diff, grad, fm, | return get_algorithm_heuristic(src, diff, grad, fm, | ||||
| workspace_limit_in_bytes, attr); | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr); | |||||
| } | } | ||||
| ConvolutionBackwardFilterImpl::Algorithm* | ConvolutionBackwardFilterImpl::Algorithm* | ||||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, | const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta); | AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta); | ||||
| if (args.grad_filter_meta.group > 1 && | if (args.grad_filter_meta.group > 1 && | ||||
| sm_algo_pack.chanwise.is_available_attribute( | sm_algo_pack.chanwise.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| // prefer special chanwise impl | // prefer special chanwise impl | ||||
| return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
| } | } | ||||
| auto get_cudnn_algo = | auto get_cudnn_algo = | ||||
| [this, &args, workspace_limit_in_bytes, | |||||
| attr]() -> ConvolutionBackwardFilterImpl::AlgoBase* { | |||||
| [this, &args, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr]() -> ConvolutionBackwardFilterImpl::AlgoBase* { | |||||
| auto cudnn_handle = cuda::cudnn_handle(this->handle()); | auto cudnn_handle = cuda::cudnn_handle(this->handle()); | ||||
| CUDNNBwdFilterDescs desc; | CUDNNBwdFilterDescs desc; | ||||
| args.init_desc(desc); | args.init_desc(desc); | ||||
| @@ -293,6 +285,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||||
| } | } | ||||
| #endif | #endif | ||||
| #if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
| MEGDNN_MARK_USED_VAR(negative_attr); | |||||
| int max_count = 0; | int max_count = 0; | ||||
| cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( | cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( | ||||
| cudnn_handle, &max_count)); | cudnn_handle, &max_count)); | ||||
| @@ -305,7 +298,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||||
| for (int i = 0; i < ret_count; ++i) { | for (int i = 0; i < ret_count; ++i) { | ||||
| if (algo_perf[i].memory > workspace_limit_in_bytes) | if (algo_perf[i].memory > workspace_limit_in_bytes) | ||||
| continue; | continue; | ||||
| if (attr & AlgoAttribute::REPRODUCIBLE) { | |||||
| if ((positive_attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { | if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { | ||||
| return reinterpret_cast<AlgoBase*>( | return reinterpret_cast<AlgoBase*>( | ||||
| sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); | sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); | ||||
| @@ -326,8 +319,8 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||||
| auto&& cast_algo = | auto&& cast_algo = | ||||
| reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo)); | reinterpret_cast<AlgoBase*>(sm_algo_pack.cudnn_from_enum(algo)); | ||||
| return reinterpret_cast<AlgoBase*>( | return reinterpret_cast<AlgoBase*>( | ||||
| megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>( | |||||
| cast_algo, attr)); | |||||
| megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>( | |||||
| cast_algo, positive_attr, negative_attr)); | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| @@ -348,27 +341,13 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||||
| } | } | ||||
| if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { | if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute< | |||||
| ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, | |||||
| workspace_limit_in_bytes, "cuda conv bwd_filter", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, | |||||
| workspace_limit_in_bytes, "cuda conv bwd_filter"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_filter", positive_attr, negative_attr); | |||||
| } else { | } else { | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute< | |||||
| ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_filter", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_filter"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv bwd_filter", positive_attr, negative_attr); | |||||
| } | } | ||||
| } | } | ||||
| @@ -59,11 +59,11 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -77,19 +77,22 @@ public: | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, | const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, | ||||
| const TensorLayout& diff, const TensorLayout& grad, | const TensorLayout& diff, const TensorLayout& grad, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(filter, filter_meta, diff, grad, | return get_algorithm_heuristic(filter, filter_meta, diff, grad, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) { | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(filter, diff, grad, | return get_algorithm_heuristic(filter, diff, grad, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -118,11 +121,11 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | ||||
| @@ -130,7 +133,8 @@ private: | |||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const TensorLayout& grad, | const TensorLayout& grad, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -146,19 +150,22 @@ public: | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, | const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(src, diff, grad, grad_meta, | return get_algorithm_heuristic(src, diff, grad, grad_meta, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) { | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(filter, diff, grad, | return get_algorithm_heuristic(filter, diff, grad, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -181,11 +188,11 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
| @@ -193,7 +200,8 @@ private: | |||||
| const TensorLayout& grad, | const TensorLayout& grad, | ||||
| const CanonizedFilterMeta& grad_meta, | const CanonizedFilterMeta& grad_meta, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -77,9 +77,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -164,7 +167,7 @@ public: | |||||
| TensorLayout& grad_pg); | TensorLayout& grad_pg); | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| auto ret = static_cast<AlgoAttribute>(0); | auto ret = static_cast<AlgoAttribute>(0); | ||||
| if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | |||||
| ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -71,9 +71,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -170,7 +173,7 @@ public: | |||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| auto ret = static_cast<AlgoAttribute>(0); | auto ret = static_cast<AlgoAttribute>(0); | ||||
| if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | |||||
| ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -76,9 +76,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -124,7 +127,7 @@ public: | |||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| auto ret = static_cast<AlgoAttribute>(0); | auto ret = static_cast<AlgoAttribute>(0); | ||||
| if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | |||||
| ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -97,8 +97,10 @@ namespace convolution3d { | |||||
| const cudnnConvolutionDescriptor_t conv_desc, | const cudnnConvolutionDescriptor_t conv_desc, | ||||
| const cudnnTensorDescriptor_t y_desc, | const cudnnTensorDescriptor_t y_desc, | ||||
| size_t workspace_limit_in_bytes, cudnnConvolutionFwdAlgo_t* algo, | size_t workspace_limit_in_bytes, cudnnConvolutionFwdAlgo_t* algo, | ||||
| const AlgoAttribute& attr) { | |||||
| MEGDNN_MARK_USED_VAR(attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| MEGDNN_MARK_USED_VAR(positive_attr); | |||||
| MEGDNN_MARK_USED_VAR(negative_attr); | |||||
| #if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
| int algo_max_count = 0; | int algo_max_count = 0; | ||||
| cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount( | cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount( | ||||
| @@ -118,7 +120,7 @@ namespace convolution3d { | |||||
| cudnn_handle, x_desc, w_desc, conv_desc, y_desc, | cudnn_handle, x_desc, w_desc, conv_desc, y_desc, | ||||
| algo_perf[i].algo, &workspace_size)); | algo_perf[i].algo, &workspace_size)); | ||||
| if (workspace_size > workspace_limit_in_bytes) continue; | if (workspace_size > workspace_limit_in_bytes) continue; | ||||
| if (!(attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
| *algo = algo_perf[i].algo; | *algo = algo_perf[i].algo; | ||||
| return true; | return true; | ||||
| } else { | } else { | ||||
| @@ -144,8 +146,11 @@ namespace convolution3d { | |||||
| const cudnnConvolutionDescriptor_t conv_desc, | const cudnnConvolutionDescriptor_t conv_desc, | ||||
| const cudnnTensorDescriptor_t dx_desc, | const cudnnTensorDescriptor_t dx_desc, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| cudnnConvolutionBwdDataAlgo_t* algo, const AlgoAttribute& attr) { | |||||
| MEGDNN_MARK_USED_VAR(attr); | |||||
| cudnnConvolutionBwdDataAlgo_t* algo, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| MEGDNN_MARK_USED_VAR(positive_attr); | |||||
| MEGDNN_MARK_USED_VAR(negative_attr); | |||||
| #if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
| int algo_max_count = 0; | int algo_max_count = 0; | ||||
| cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | ||||
| @@ -166,7 +171,7 @@ namespace convolution3d { | |||||
| cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, | cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, | ||||
| algo_perf[i].algo, &workspace_size)); | algo_perf[i].algo, &workspace_size)); | ||||
| if (workspace_size > workspace_limit_in_bytes) continue; | if (workspace_size > workspace_limit_in_bytes) continue; | ||||
| if (!(attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
| *algo = algo_perf[i].algo; | *algo = algo_perf[i].algo; | ||||
| return true; | return true; | ||||
| } else { | } else { | ||||
| @@ -193,8 +198,11 @@ namespace convolution3d { | |||||
| const cudnnConvolutionDescriptor_t conv_desc, | const cudnnConvolutionDescriptor_t conv_desc, | ||||
| const cudnnFilterDescriptor_t dw_desc, | const cudnnFilterDescriptor_t dw_desc, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| cudnnConvolutionBwdFilterAlgo_t* algo, const AlgoAttribute& attr) { | |||||
| MEGDNN_MARK_USED_VAR(attr); | |||||
| cudnnConvolutionBwdFilterAlgo_t* algo, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| MEGDNN_MARK_USED_VAR(positive_attr); | |||||
| MEGDNN_MARK_USED_VAR(negative_attr); | |||||
| #if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
| int algo_max_count = 0; | int algo_max_count = 0; | ||||
| cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( | cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( | ||||
| @@ -215,7 +223,7 @@ namespace convolution3d { | |||||
| cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, | cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, | ||||
| algo_perf[i].algo, &workspace_size)); | algo_perf[i].algo, &workspace_size)); | ||||
| if (workspace_size > workspace_limit_in_bytes) continue; | if (workspace_size > workspace_limit_in_bytes) continue; | ||||
| if (!(attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
| if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { | |||||
| *algo = algo_perf[i].algo; | *algo = algo_perf[i].algo; | ||||
| return true; | return true; | ||||
| } else { | } else { | ||||
| @@ -235,7 +243,6 @@ namespace convolution3d { | |||||
| #endif | #endif | ||||
| } | } | ||||
| } // namespace convolution3d | } // namespace convolution3d | ||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -33,16 +33,18 @@ Convolution3DForwardImpl::Algorithm* | |||||
| Convolution3DForwardImpl::get_algorithm_heuristic( | Convolution3DForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = check_layout_fwd(src, filter, dst); | auto fm = check_layout_fwd(src, filter, dst); | ||||
| return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, | return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, | ||||
| attr); | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| Convolution3DForwardImpl::Algorithm* | Convolution3DForwardImpl::Algorithm* | ||||
| Convolution3DForwardImpl::get_algorithm_heuristic( | Convolution3DForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const CanonizedFilterMeta& filter, | const TensorLayout& src, const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, src, filter, dst); | AlgoBase::SizeArgs args(this, src, filter, dst); | ||||
| #if CUDNN_MAJOR < 7 || (CUDNN_MAJOR == 7 && CUDNN_MINOR < 5) | #if CUDNN_MAJOR < 7 || (CUDNN_MAJOR == 7 && CUDNN_MINOR < 5) | ||||
| @@ -51,25 +53,27 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||||
| // version is lower than v7.5.0 is still slower than our implementation | // version is lower than v7.5.0 is still slower than our implementation | ||||
| // in many channel-wise cases | // in many channel-wise cases | ||||
| if (sm_algo_pack.chanwise.is_available_attribute( | if (sm_algo_pack.chanwise.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
| } | } | ||||
| } | } | ||||
| #endif | #endif | ||||
| auto prefer_1x1x1 = [&args, attr, workspace_limit_in_bytes]() { | |||||
| auto prefer_1x1x1 = [&args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes]() { | |||||
| const size_t MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO = 4; | const size_t MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO = 4; | ||||
| size_t batch_size = args.src_layout->shape[0]; | size_t batch_size = args.src_layout->shape[0]; | ||||
| if (batch_size > MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO) { | if (batch_size > MAX_BATCH_SIZE_FOR_1x1x1_MAT_ALGO) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| return sm_algo_pack.a1x1x1.is_available_attribute( | return sm_algo_pack.a1x1x1.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes); | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes); | |||||
| }; | }; | ||||
| auto get_cudnn_algo = | auto get_cudnn_algo = | ||||
| [this, &args, workspace_limit_in_bytes, | |||||
| attr]() -> Convolution3DForwardImpl::AlgoBase* { | |||||
| [this, &args, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr]() -> Convolution3DForwardImpl::AlgoBase* { | |||||
| auto cudnn_handle = cuda::cudnn_handle(this->handle()); | auto cudnn_handle = cuda::cudnn_handle(this->handle()); | ||||
| cudnnConvolutionFwdAlgo_t algo; | cudnnConvolutionFwdAlgo_t algo; | ||||
| CUDNNForwardDescs desc; | CUDNNForwardDescs desc; | ||||
| @@ -78,11 +82,12 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||||
| bool got = cudnn_get_convolution_fwd_algo_helper( | bool got = cudnn_get_convolution_fwd_algo_helper( | ||||
| cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, | cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, | ||||
| desc.conv_desc.desc, desc.dst_desc.desc, | desc.conv_desc.desc, desc.dst_desc.desc, | ||||
| workspace_limit_in_bytes, &algo, attr); | |||||
| workspace_limit_in_bytes, &algo, positive_attr, negative_attr); | |||||
| if (got) { | if (got) { | ||||
| return static_cast<AlgoBase*>( | return static_cast<AlgoBase*>( | ||||
| megdnn::get_algo_with_attribute<Convolution3DForwardImpl>( | |||||
| sm_algo_pack.cudnn_from_enum(algo), attr)); | |||||
| megdnn::get_algo_match_attribute<Convolution3DForwardImpl>( | |||||
| sm_algo_pack.cudnn_from_enum(algo), positive_attr, | |||||
| negative_attr)); | |||||
| } else { | } else { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -108,15 +113,9 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||||
| args = orig_args; | args = orig_args; | ||||
| } | } | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<Convolution3DForwardImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv3d fwd", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<Convolution3DForwardImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv3d fwd"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<Convolution3DForwardImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv3d fwd", positive_attr, negative_attr); | |||||
| } | } | ||||
| std::vector<Convolution3DForwardImpl::Algorithm*> | std::vector<Convolution3DForwardImpl::Algorithm*> | ||||
| @@ -169,28 +168,30 @@ Convolution3DBackwardDataImpl::Algorithm* | |||||
| Convolution3DBackwardDataImpl::get_algorithm_heuristic( | Convolution3DBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = check_layout_fwd(grad, filter, diff); | auto fm = check_layout_fwd(grad, filter, diff); | ||||
| return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | ||||
| attr); | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| Convolution3DBackwardDataImpl::Algorithm* | Convolution3DBackwardDataImpl::Algorithm* | ||||
| Convolution3DBackwardDataImpl::get_algorithm_heuristic( | Convolution3DBackwardDataImpl::get_algorithm_heuristic( | ||||
| const CanonizedFilterMeta& filter, const TensorLayout& diff, | const CanonizedFilterMeta& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, filter, diff, grad); | AlgoBase::SizeArgs args(this, filter, diff, grad); | ||||
| if (args.filter_meta.group > 1 && | if (args.filter_meta.group > 1 && | ||||
| sm_algo_pack.chanwise.is_available_attribute( | sm_algo_pack.chanwise.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
| } | } | ||||
| auto get_cudnn_algo = | auto get_cudnn_algo = | ||||
| [this, &args, workspace_limit_in_bytes, | |||||
| attr]() -> Convolution3DBackwardDataImpl::AlgoBase* { | |||||
| [this, &args, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr]() -> Convolution3DBackwardDataImpl::AlgoBase* { | |||||
| auto cudnn_handle = cuda::cudnn_handle(this->handle()); | auto cudnn_handle = cuda::cudnn_handle(this->handle()); | ||||
| cudnnConvolutionBwdDataAlgo_t algo; | cudnnConvolutionBwdDataAlgo_t algo; | ||||
| CUDNNBwdDataDescs desc; | CUDNNBwdDataDescs desc; | ||||
| @@ -198,11 +199,12 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||||
| bool got = cudnn_get_convolution_bwd_data_algo_helper( | bool got = cudnn_get_convolution_bwd_data_algo_helper( | ||||
| cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc, | cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc, | ||||
| desc.conv_desc.desc, desc.grad_desc.desc, | desc.conv_desc.desc, desc.grad_desc.desc, | ||||
| workspace_limit_in_bytes, &algo, attr); | |||||
| workspace_limit_in_bytes, &algo, positive_attr, negative_attr); | |||||
| if (got) { | if (got) { | ||||
| return static_cast<AlgoBase*>(megdnn::get_algo_with_attribute< | |||||
| return static_cast<AlgoBase*>(megdnn::get_algo_match_attribute< | |||||
| Convolution3DBackwardDataImpl>( | Convolution3DBackwardDataImpl>( | ||||
| sm_algo_pack.cudnn_from_enum(algo), attr)); | |||||
| sm_algo_pack.cudnn_from_enum(algo), positive_attr, | |||||
| negative_attr)); | |||||
| } else { | } else { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -224,15 +226,9 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||||
| args = orig_args; | args = orig_args; | ||||
| } | } | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<Convolution3DBackwardDataImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv3d bwd data", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<Convolution3DBackwardDataImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv3d bwd data"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<Convolution3DBackwardDataImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv3d bwd data", positive_attr, negative_attr); | |||||
| } | } | ||||
| size_t Convolution3DBackwardDataImpl::get_workspace_in_bytes( | size_t Convolution3DBackwardDataImpl::get_workspace_in_bytes( | ||||
| @@ -269,28 +265,30 @@ Convolution3DBackwardFilterImpl::Algorithm* | |||||
| Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = check_layout_fwd(src, grad, diff); | auto fm = check_layout_fwd(src, grad, diff); | ||||
| return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | ||||
| attr); | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| Convolution3DBackwardFilterImpl::Algorithm* | Convolution3DBackwardFilterImpl::Algorithm* | ||||
| Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, src, diff, grad); | AlgoBase::SizeArgs args(this, src, diff, grad); | ||||
| if (args.grad_filter_meta.group > 1 && | if (args.grad_filter_meta.group > 1 && | ||||
| sm_algo_pack.chanwise.is_available_attribute( | sm_algo_pack.chanwise.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
| } | } | ||||
| auto get_cudnn_algo = | auto get_cudnn_algo = | ||||
| [this, &args, workspace_limit_in_bytes, | |||||
| attr]() -> Convolution3DBackwardFilterImpl::AlgoBase* { | |||||
| [this, &args, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr]() -> Convolution3DBackwardFilterImpl::AlgoBase* { | |||||
| auto cudnn_handle = cuda::cudnn_handle(this->handle()); | auto cudnn_handle = cuda::cudnn_handle(this->handle()); | ||||
| cudnnConvolutionBwdFilterAlgo_t algo; | cudnnConvolutionBwdFilterAlgo_t algo; | ||||
| CUDNNBwdFilterDescs desc; | CUDNNBwdFilterDescs desc; | ||||
| @@ -298,11 +296,12 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||||
| bool got = cudnn_get_convolution_bwd_filter_algo_helper( | bool got = cudnn_get_convolution_bwd_filter_algo_helper( | ||||
| cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc, | cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc, | ||||
| desc.conv_desc.desc, desc.grad_desc.desc, | desc.conv_desc.desc, desc.grad_desc.desc, | ||||
| workspace_limit_in_bytes, &algo, attr); | |||||
| workspace_limit_in_bytes, &algo, positive_attr, negative_attr); | |||||
| if (got) { | if (got) { | ||||
| return static_cast<AlgoBase*>(megdnn::get_algo_with_attribute< | |||||
| return static_cast<AlgoBase*>(megdnn::get_algo_match_attribute< | |||||
| Convolution3DBackwardFilterImpl>( | Convolution3DBackwardFilterImpl>( | ||||
| sm_algo_pack.cudnn_from_enum(algo), attr)); | |||||
| sm_algo_pack.cudnn_from_enum(algo), positive_attr, | |||||
| negative_attr)); | |||||
| } else { | } else { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -323,15 +322,9 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||||
| args = orig_args; | args = orig_args; | ||||
| } | } | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<Convolution3DBackwardFilterImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv3d bwd filter", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<Convolution3DBackwardFilterImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv3d bwd filter"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<Convolution3DBackwardFilterImpl>( | |||||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||||
| "cuda conv3d bwd filter", positive_attr, negative_attr); | |||||
| } | } | ||||
| size_t Convolution3DBackwardFilterImpl::get_workspace_in_bytes( | size_t Convolution3DBackwardFilterImpl::get_workspace_in_bytes( | ||||
| @@ -25,9 +25,11 @@ public: | |||||
| const CanonizedFilterMeta& filter, | const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(src, filter, dst, | return get_algorithm_heuristic(src, filter, dst, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
| @@ -48,19 +50,19 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
| const CanonizedFilterMeta& filter, | const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -73,9 +75,11 @@ public: | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const CanonizedFilterMeta& filter, const TensorLayout& diff, | const CanonizedFilterMeta& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(filter, diff, grad, | return get_algorithm_heuristic(filter, diff, grad, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | size_t get_workspace_in_bytes(const TensorLayout& filter, | ||||
| @@ -98,18 +102,19 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const TensorLayout& grad, | const TensorLayout& grad, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -122,13 +127,14 @@ public: | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const CanonizedFilterMeta& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) { | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(src, diff, grad, | return get_algorithm_heuristic(src, diff, grad, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -149,18 +155,19 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const CanonizedFilterMeta& grad, | const CanonizedFilterMeta& grad, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -82,9 +82,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -75,9 +75,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -70,9 +70,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -59,10 +59,12 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | |||||
| const TensorLayout& mask, | const TensorLayout& mask, | ||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = make_canonized_filter_meta(im.ndim, filter, offset); | auto fm = make_canonized_filter_meta(im.ndim, filter, offset); | ||||
| return get_algorithm_heuristic(im, fm, offset, mask, dst, | return get_algorithm_heuristic(im, fm, offset, mask, dst, | ||||
| workspace_limit_in_bytes, attr); | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr); | |||||
| } | } | ||||
| AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | ||||
| @@ -71,17 +73,20 @@ AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, | |||||
| const TensorLayout& mask, | const TensorLayout& mask, | ||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, im, filter, offset, mask, dst); | AlgoBase::SizeArgs args(this, im, filter, offset, mask, dst); | ||||
| if (sm_algo_pack.algo_matmul.is_available_attribute( | if (sm_algo_pack.algo_matmul.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.algo_matmul; | return &sm_algo_pack.algo_matmul; | ||||
| } | } | ||||
| megdnn_throw(ssprintf( | |||||
| "no deformable conv fwd algorithm with attribute%s , args(%s) and " | |||||
| "workspace limit (%zu bytes)", | |||||
| Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(), | |||||
| workspace_limit_in_bytes)); | |||||
| megdnn_throw( | |||||
| ssprintf("no deformable conv fwd algorithm without attribute(%s) " | |||||
| "with attribute(%s) , args(%s) and " | |||||
| "workspace limit (%zu bytes)", | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
| } | } | ||||
| const char* Fwd::get_algorithm_set_name() const { | const char* Fwd::get_algorithm_set_name() const { | ||||
| @@ -114,28 +119,33 @@ std::vector<AlgoBwdFlt*> BwdFlt::get_all_algorithms(const TensorLayout& /* im */ | |||||
| AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | ||||
| const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
| const TensorLayout& mask, const TensorLayout& out_grad, | const TensorLayout& mask, const TensorLayout& out_grad, | ||||
| const TensorLayout& filter_grad, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| const TensorLayout& filter_grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = make_canonized_filter_meta(im.ndim, filter_grad, offset); | auto fm = make_canonized_filter_meta(im.ndim, filter_grad, offset); | ||||
| return get_algorithm_heuristic(im, offset, mask, out_grad, fm, | return get_algorithm_heuristic(im, offset, mask, out_grad, fm, | ||||
| workspace_limit_in_bytes, attr); | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr); | |||||
| } | } | ||||
| AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( | ||||
| const TensorLayout& im, const TensorLayout& offset, | const TensorLayout& im, const TensorLayout& offset, | ||||
| const TensorLayout& mask, const TensorLayout& out_grad, | const TensorLayout& mask, const TensorLayout& out_grad, | ||||
| const CanonizedFilterMeta& filter_grad, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| const CanonizedFilterMeta& filter_grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, im, offset, mask, out_grad, filter_grad); | AlgoBase::SizeArgs args(this, im, offset, mask, out_grad, filter_grad); | ||||
| if (sm_algo_pack.algo_matmul.is_available_attribute( | if (sm_algo_pack.algo_matmul.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.algo_matmul; | return &sm_algo_pack.algo_matmul; | ||||
| } | } | ||||
| megdnn_throw( | megdnn_throw( | ||||
| ssprintf("no deformable conv bwd filter algorithm with " | |||||
| "attribute%s, args(%s) and " | |||||
| ssprintf("no deformable conv bwd filter algorithm without " | |||||
| "attribute(%s) with " | |||||
| "attribute(%s), args(%s) and " | |||||
| "workspace limit (%zu bytes)", | "workspace limit (%zu bytes)", | ||||
| Algorithm::attribute_str(attr).c_str(), | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes)); | args.to_string().c_str(), workspace_limit_in_bytes)); | ||||
| } | } | ||||
| @@ -176,11 +186,12 @@ AlgoBwdData* BwdData::get_algorithm_heuristic( | |||||
| const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | const TensorLayout& offset_grad, const TensorLayout& mask_grad, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = make_canonized_filter_meta(im.ndim, filter, offset); | auto fm = make_canonized_filter_meta(im.ndim, filter, offset); | ||||
| return get_algorithm_heuristic(im, fm, offset, mask, out_grad, im_grad, | |||||
| offset_grad, mask_grad, | |||||
| workspace_limit_in_bytes, attr); | |||||
| return get_algorithm_heuristic( | |||||
| im, fm, offset, mask, out_grad, im_grad, offset_grad, mask_grad, | |||||
| workspace_limit_in_bytes, positive_attr, negative_attr); | |||||
| } | } | ||||
| AlgoBwdData* BwdData::get_algorithm_heuristic( | AlgoBwdData* BwdData::get_algorithm_heuristic( | ||||
| @@ -188,18 +199,21 @@ AlgoBwdData* BwdData::get_algorithm_heuristic( | |||||
| const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | const TensorLayout& offset_grad, const TensorLayout& mask_grad, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, im, filter, offset, mask, out_grad, im_grad, | AlgoBase::SizeArgs args(this, im, filter, offset, mask, out_grad, im_grad, | ||||
| offset_grad, mask_grad); | offset_grad, mask_grad); | ||||
| if (sm_algo_pack.algo_matmul.is_available_attribute( | if (sm_algo_pack.algo_matmul.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.algo_matmul; | return &sm_algo_pack.algo_matmul; | ||||
| } | } | ||||
| megdnn_throw( | megdnn_throw( | ||||
| ssprintf("no deformable conv bwd data algorithm with attribute%s, " | |||||
| ssprintf("no deformable conv bwd data algorithm without " | |||||
| "attribute(%s) with attribute(%s), " | |||||
| "args(%s) and " | "args(%s) and " | ||||
| "workspace limit (%zu bytes)", | "workspace limit (%zu bytes)", | ||||
| Algorithm::attribute_str(attr).c_str(), | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes)); | args.to_string().c_str(), workspace_limit_in_bytes)); | ||||
| } | } | ||||
| @@ -36,7 +36,8 @@ public: | |||||
| const TensorLayout& mask, | const TensorLayout& mask, | ||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| @@ -54,13 +55,12 @@ protected: | |||||
| const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& offset, | |||||
| const TensorLayout& mask, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& im, const TensorLayout& filter, | |||||
| const TensorLayout& offset, const TensorLayout& mask, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -81,7 +81,8 @@ public: | |||||
| const TensorLayout& out_grad, | const TensorLayout& out_grad, | ||||
| const CanonizedFilterMeta& filter_grad, | const CanonizedFilterMeta& filter_grad, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| size_t get_workspace_in_bytes(const TensorLayout& im, | size_t get_workspace_in_bytes(const TensorLayout& im, | ||||
| const TensorLayout& offset, | const TensorLayout& offset, | ||||
| @@ -105,13 +106,12 @@ protected: | |||||
| const TensorLayout& mask, const TensorLayout& out_grad, | const TensorLayout& mask, const TensorLayout& out_grad, | ||||
| const TensorLayout& filter_grad) override; | const TensorLayout& filter_grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& im, | |||||
| const TensorLayout& offset, | |||||
| const TensorLayout& mask, | |||||
| const TensorLayout& out_grad, | |||||
| const TensorLayout& filter_grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& im, const TensorLayout& offset, | |||||
| const TensorLayout& mask, const TensorLayout& out_grad, | |||||
| const TensorLayout& filter_grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -132,7 +132,8 @@ public: | |||||
| const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | const TensorLayout& offset_grad, const TensorLayout& mask_grad, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr); | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| size_t get_workspace_in_bytes(const TensorLayout& im, | size_t get_workspace_in_bytes(const TensorLayout& im, | ||||
| const TensorLayout& filter, | const TensorLayout& filter, | ||||
| @@ -166,8 +167,8 @@ protected: | |||||
| const TensorLayout& offset, const TensorLayout& mask, | const TensorLayout& offset, const TensorLayout& mask, | ||||
| const TensorLayout& out_grad, const TensorLayout& im_grad, | const TensorLayout& out_grad, const TensorLayout& im_grad, | ||||
| const TensorLayout& offset_grad, const TensorLayout& mask_grad, | const TensorLayout& offset_grad, const TensorLayout& mask_grad, | ||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -61,9 +61,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -61,9 +61,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -62,9 +62,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -20,30 +20,32 @@ using namespace cuda; | |||||
| /* ============== LocalShareForwardImpl ============== */ | /* ============== LocalShareForwardImpl ============== */ | ||||
| LocalShareForwardImpl::Algorithm* | LocalShareForwardImpl::Algorithm* | ||||
| LocalShareForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) { | |||||
| LocalShareForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, src, filter, dst); | AlgoBase::SizeArgs args(this, src, filter, dst); | ||||
| if (sm_algo_pack.batch_size_aware_chwn_small_image | if (sm_algo_pack.batch_size_aware_chwn_small_image | ||||
| .is_available_attribute(args, attr, | |||||
| .is_available_attribute(args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes)) { | workspace_limit_in_bytes)) { | ||||
| return &sm_algo_pack.batch_size_aware_chwn_small_image; | return &sm_algo_pack.batch_size_aware_chwn_small_image; | ||||
| } | } | ||||
| if (sm_algo_pack.batch_size_aware_chwn.is_available_attribute( | if (sm_algo_pack.batch_size_aware_chwn.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.batch_size_aware_chwn; | return &sm_algo_pack.batch_size_aware_chwn; | ||||
| } | } | ||||
| if (sm_algo_pack.batched_matmul.is_available_attribute( | if (sm_algo_pack.batched_matmul.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.batched_matmul; | return &sm_algo_pack.batched_matmul; | ||||
| } | } | ||||
| megdnn_throw(ssprintf( | |||||
| "no local share conv algorithm with attribute%s, args(%s) and " | |||||
| "workspace limit (%zu bytes)", | |||||
| Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(), | |||||
| workspace_limit_in_bytes)); | |||||
| megdnn_throw( | |||||
| ssprintf("no local share conv algorithm without attribute(%s) with " | |||||
| "attribute(%s), args(%s) and " | |||||
| "workspace limit (%zu bytes)", | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
| } | } | ||||
| std::vector<LocalShareForwardImpl::Algorithm*> | std::vector<LocalShareForwardImpl::Algorithm*> | ||||
| @@ -79,21 +81,24 @@ LocalShareBackwardDataImpl::Algorithm* | |||||
| LocalShareBackwardDataImpl::get_algorithm_heuristic( | LocalShareBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, filter, diff, grad); | AlgoBase::SizeArgs args(this, filter, diff, grad); | ||||
| if (sm_algo_pack.implicit_gemm.is_available_attribute( | if (sm_algo_pack.implicit_gemm.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.implicit_gemm; | return &sm_algo_pack.implicit_gemm; | ||||
| } | } | ||||
| if (sm_algo_pack.batched_matmul.is_available_attribute( | if (sm_algo_pack.batched_matmul.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.batched_matmul; | return &sm_algo_pack.batched_matmul; | ||||
| } | } | ||||
| megdnn_throw(ssprintf( | |||||
| "no local share bwd data algorithm with attribute%s args(%s) and " | |||||
| "workspace limit (%zu bytes)", | |||||
| Algorithm::attribute_str(attr).c_str(), args.to_string().c_str(), | |||||
| workspace_limit_in_bytes)); | |||||
| megdnn_throw( | |||||
| ssprintf("no local share bwd data algorithm without attribute(%s) " | |||||
| "with attribute(%s) args(%s) and " | |||||
| "workspace limit (%zu bytes)", | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes)); | |||||
| } | } | ||||
| std::vector<LocalShareBackwardDataImpl::Algorithm*> | std::vector<LocalShareBackwardDataImpl::Algorithm*> | ||||
| @@ -129,21 +134,24 @@ LocalShareBackwardFilterImpl::Algorithm* | |||||
| LocalShareBackwardFilterImpl::get_algorithm_heuristic( | LocalShareBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, src, diff, grad); | AlgoBase::SizeArgs args(this, src, diff, grad); | ||||
| if (sm_algo_pack.implicit_gemm.is_available_attribute( | if (sm_algo_pack.implicit_gemm.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.implicit_gemm; | return &sm_algo_pack.implicit_gemm; | ||||
| } | } | ||||
| if (sm_algo_pack.batched_matmul.is_available_attribute( | if (sm_algo_pack.batched_matmul.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.batched_matmul; | return &sm_algo_pack.batched_matmul; | ||||
| } | } | ||||
| megdnn_throw( | megdnn_throw( | ||||
| ssprintf("no local share bwd filter algorithm with attribute%s, " | |||||
| ssprintf("no local share bwd filter algorithm without " | |||||
| "attribute(%s) with attribute(%s), " | |||||
| "args(%s) and " | "args(%s) and " | ||||
| "workspace limit (%zu bytes)", | "workspace limit (%zu bytes)", | ||||
| Algorithm::attribute_str(attr).c_str(), | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| args.to_string().c_str(), workspace_limit_in_bytes)); | args.to_string().c_str(), workspace_limit_in_bytes)); | ||||
| } | } | ||||
| @@ -39,11 +39,12 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -71,11 +72,11 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -104,11 +105,11 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -85,9 +85,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -30,35 +30,30 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | AlgoBase::SizeArgs args{this, A, B, C}; | ||||
| if (sm_algo_pack.cublas.is_available_attribute(args, attr, | |||||
| workspace_limit_in_bytes)) { | |||||
| if (sm_algo_pack.cublas.is_available_attribute( | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.cublas; | return &sm_algo_pack.cublas; | ||||
| } | } | ||||
| #if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
| if (sm_algo_pack.cublas_lt.is_available_attribute( | if (sm_algo_pack.cublas_lt.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.cublas_lt; | return &sm_algo_pack.cublas_lt; | ||||
| } | } | ||||
| #endif | #endif | ||||
| #if CUDA_VERSION >= 10000 | #if CUDA_VERSION >= 10000 | ||||
| if (sm_algo_pack.wmma_uint4x4x32.is_available_attribute( | if (sm_algo_pack.wmma_uint4x4x32.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.wmma_uint4x4x32; | return &sm_algo_pack.wmma_uint4x4x32; | ||||
| } | } | ||||
| #endif | #endif | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<MatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "matrix mul forward", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<MatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "matrix mul forward"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<MatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "matrix mul forward", positive_attr, negative_attr); | |||||
| } | } | ||||
| size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | ||||
| @@ -57,11 +57,10 @@ protected: | |||||
| std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, | ||||
| const TensorLayout& B, | const TensorLayout& B, | ||||
| const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| private: | private: | ||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| @@ -65,9 +65,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -31,21 +31,16 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| BatchedMatrixMulForwardImpl::Algorithm* | BatchedMatrixMulForwardImpl::Algorithm* | ||||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | AlgoBase::SizeArgs args{this, A, B, C}; | ||||
| if (sm_algo_pack.algo_default.is_available_attribute( | if (sm_algo_pack.algo_default.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.algo_default; | return &sm_algo_pack.algo_default; | ||||
| } | } | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward", positive_attr, negative_attr); | |||||
| } | } | ||||
| size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( | size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( | ||||
| @@ -36,11 +36,11 @@ private: | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, | |||||
| size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*attr*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override; | |||||
| const char* get_algorithm_set_name() const override { | const char* get_algorithm_set_name() const override { | ||||
| return "FALLBACK BATCHED MATMUL"; | return "FALLBACK BATCHED MATMUL"; | ||||
| @@ -280,20 +280,23 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); | ||||
| auto result = get_algorithm_heuristic_with_ncb( | auto result = get_algorithm_heuristic_with_ncb( | ||||
| fparam, workspace_limit_in_bytes, attr); | |||||
| fparam, workspace_limit_in_bytes, positive_attr, negative_attr); | |||||
| if (result == nullptr) { | if (result == nullptr) { | ||||
| result = naive::ConvBiasForwardImpl::get_algorithm_heuristic( | result = naive::ConvBiasForwardImpl::get_algorithm_heuristic( | ||||
| src, filter, bias, z, dst, workspace_limit_in_bytes, attr); | |||||
| src, filter, bias, z, dst, workspace_limit_in_bytes, | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| return result; | return result; | ||||
| } | } | ||||
| ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( | ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( | ||||
| const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo_data_type = param.deduce_algo_data_type(); | auto algo_data_type = param.deduce_algo_data_type(); | ||||
| auto suggest_category_order = suggest_algo_category_order(param); | auto suggest_category_order = suggest_algo_category_order(param); | ||||
| for (auto category : suggest_category_order) { | for (auto category : suggest_category_order) { | ||||
| @@ -301,7 +304,8 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( | |||||
| ConvBiasImpl::Algorithm* heuristic_algo = nullptr; | ConvBiasImpl::Algorithm* heuristic_algo = nullptr; | ||||
| for (auto i : origin_algos) { | for (auto i : origin_algos) { | ||||
| bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute( | bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute( | ||||
| param, AlgoSelectionStrategy::HEURISTIC, attr); | |||||
| param, AlgoSelectionStrategy::HEURISTIC, positive_attr, | |||||
| negative_attr); | |||||
| if (usable_attribute && | if (usable_attribute && | ||||
| static_cast<AlgoBase*>(i)->get_workspace(param) <= | static_cast<AlgoBase*>(i)->get_workspace(param) <= | ||||
| workspace_limit_in_bytes) { | workspace_limit_in_bytes) { | ||||
| @@ -497,7 +501,8 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | |||||
| if (!m_prev_selected_algo || | if (!m_prev_selected_algo || | ||||
| memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | ||||
| m_prev_selected_algo = get_algorithm_heuristic_with_ncb( | m_prev_selected_algo = get_algorithm_heuristic_with_ncb( | ||||
| param, workspace_size, AlgoAttribute::DEFAULT); | |||||
| param, workspace_size, AlgoAttribute::DEFAULT, | |||||
| AlgoAttribute::DEFAULT); | |||||
| m_prev_selected_algo_sizep = param; | m_prev_selected_algo_sizep = param; | ||||
| } | } | ||||
| return m_prev_selected_algo; | return m_prev_selected_algo; | ||||
| @@ -89,13 +89,12 @@ public: | |||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| //! implemented by get_algorithm_heuristic_with_ncb() | //! implemented by get_algorithm_heuristic_with_ncb() | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& bias, | |||||
| const TensorLayout& z, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| //! size param for kernels with non-contiguous batch | //! size param for kernels with non-contiguous batch | ||||
| struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam { | struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam { | ||||
| @@ -319,11 +318,14 @@ public: | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool usable_attribute( | |||||
| const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy, | |||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const { | |||||
| return contain_attribute(attr) && | |||||
| bool usable_attribute(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy, | |||||
| const AlgoAttribute& positive_attr = | |||||
| AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = | |||||
| AlgoAttribute::DEFAULT) const { | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| usable(param, algo_selection_strategy); | usable(param, algo_selection_strategy); | ||||
| } | } | ||||
| @@ -361,7 +363,8 @@ protected: | |||||
| virtual Algorithm* get_algorithm_heuristic_with_ncb( | virtual Algorithm* get_algorithm_heuristic_with_ncb( | ||||
| const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| @@ -198,13 +198,15 @@ std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms( | |||||
| ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( | ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); | auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); | ||||
| auto result = get_algorithm_heuristic_with_ncb( | auto result = get_algorithm_heuristic_with_ncb( | ||||
| fparam, workspace_limit_in_bytes, attr); | |||||
| fparam, workspace_limit_in_bytes, positive_attr, negative_attr); | |||||
| if (result == nullptr) { | if (result == nullptr) { | ||||
| result = naive::ConvolutionForwardImpl::get_algorithm_heuristic( | result = naive::ConvolutionForwardImpl::get_algorithm_heuristic( | ||||
| src, filter, dst, workspace_limit_in_bytes, attr); | |||||
| src, filter, dst, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr); | |||||
| } | } | ||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -312,7 +314,8 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||||
| ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb( | ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb( | ||||
| const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo_data_type = param.deduce_algo_data_type(); | auto algo_data_type = param.deduce_algo_data_type(); | ||||
| auto suggest_category_order = suggest_algo_category_order(param); | auto suggest_category_order = suggest_algo_category_order(param); | ||||
| for (auto category : suggest_category_order) { | for (auto category : suggest_category_order) { | ||||
| @@ -320,7 +323,8 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb( | |||||
| ConvolutionImpl::Algorithm* heuristic_algo = nullptr; | ConvolutionImpl::Algorithm* heuristic_algo = nullptr; | ||||
| for (auto i : origin_algos) { | for (auto i : origin_algos) { | ||||
| bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute( | bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute( | ||||
| param, AlgoSelectionStrategy::HEURISTIC, attr); | |||||
| param, AlgoSelectionStrategy::HEURISTIC, positive_attr, | |||||
| negative_attr); | |||||
| if (usable_attribute && | if (usable_attribute && | ||||
| static_cast<AlgoBase*>(i)->get_workspace(param) <= | static_cast<AlgoBase*>(i)->get_workspace(param) <= | ||||
| workspace_limit_in_bytes) { | workspace_limit_in_bytes) { | ||||
| @@ -391,7 +395,8 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( | |||||
| if (!m_prev_selected_algo || | if (!m_prev_selected_algo || | ||||
| memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | ||||
| m_prev_selected_algo = get_algorithm_heuristic_with_ncb( | m_prev_selected_algo = get_algorithm_heuristic_with_ncb( | ||||
| param, workspace_size, AlgoAttribute::DEFAULT); | |||||
| param, workspace_size, AlgoAttribute::DEFAULT, | |||||
| AlgoAttribute::DEFAULT); | |||||
| m_prev_selected_algo_sizep = param; | m_prev_selected_algo_sizep = param; | ||||
| } | } | ||||
| return m_prev_selected_algo; | return m_prev_selected_algo; | ||||
| @@ -513,15 +518,17 @@ ConvolutionBackwardDataImpl::Algorithm* | |||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| if (param().format == param::Convolution::Format::NHWCD4 || | if (param().format == param::Convolution::Format::NHWCD4 || | ||||
| param().format == param::Convolution::Format::NCHW4) { | param().format == param::Convolution::Format::NCHW4) { | ||||
| return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic( | return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| filter, diff, grad, workspace_limit_in_bytes, attr); | |||||
| filter, diff, grad, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr); | |||||
| } | } | ||||
| auto fparam = make_ncb_kern_size_param(filter, diff, grad); | auto fparam = make_ncb_kern_size_param(filter, diff, grad); | ||||
| return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes, | return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes, | ||||
| attr); | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| ConvolutionBackwardDataImpl::NCBKernSizeParam | ConvolutionBackwardDataImpl::NCBKernSizeParam | ||||
| @@ -666,15 +673,16 @@ ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb( | |||||
| ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb( | ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb( | ||||
| const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| if (param.filter_meta.group != 1) { | if (param.filter_meta.group != 1) { | ||||
| auto p1g = param; | auto p1g = param; | ||||
| p1g.filter_meta.group = 1; | p1g.filter_meta.group = 1; | ||||
| return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes, | return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes, | ||||
| attr); | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes, | return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes, | ||||
| attr); | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( | size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( | ||||
| @@ -729,10 +737,12 @@ ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( | |||||
| ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
| ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic( | ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic( | ||||
| const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| for (auto i : ncb_1g_get_all_algorithms(param)) { | for (auto i : ncb_1g_get_all_algorithms(param)) { | ||||
| if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) { | if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) { | ||||
| if (i->contain_attribute(attr)) { | |||||
| if (i->contain_attribute_all(positive_attr) && | |||||
| !i->contain_attribute_any(negative_attr)) { | |||||
| return i; | return i; | ||||
| } | } | ||||
| } | } | ||||
| @@ -783,7 +793,7 @@ ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) { | |||||
| memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | memcmp(&m_prev_selected_algo_sizep, ¶m, sizeof(NCBKernSizeParam))) { | ||||
| m_prev_selected_algo = ncb_1g_get_algorithm_heuristic( | m_prev_selected_algo = ncb_1g_get_algorithm_heuristic( | ||||
| param, std::numeric_limits<size_t>::max(), | param, std::numeric_limits<size_t>::max(), | ||||
| AlgoAttribute::DEFAULT); | |||||
| AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||||
| m_prev_selected_algo_sizep = param; | m_prev_selected_algo_sizep = param; | ||||
| } | } | ||||
| return m_prev_selected_algo; | return m_prev_selected_algo; | ||||
| @@ -86,11 +86,11 @@ public: | |||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| //! implemented by get_algorithm_heuristic_with_ncb() | //! implemented by get_algorithm_heuristic_with_ncb() | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| //! size param for kernels with non-contiguous batch | //! size param for kernels with non-contiguous batch | ||||
| struct NCBKernSizeParam { | struct NCBKernSizeParam { | ||||
| @@ -238,11 +238,14 @@ public: | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool usable_attribute( | |||||
| const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy, | |||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const { | |||||
| return contain_attribute(attr) && | |||||
| bool usable_attribute(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy, | |||||
| const AlgoAttribute& positive_attr = | |||||
| AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = | |||||
| AlgoAttribute::DEFAULT) const { | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| usable(param, algo_selection_strategy); | usable(param, algo_selection_strategy); | ||||
| } | } | ||||
| @@ -272,7 +275,8 @@ protected: | |||||
| virtual Algorithm* get_algorithm_heuristic_with_ncb( | virtual Algorithm* get_algorithm_heuristic_with_ncb( | ||||
| const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| @@ -322,11 +326,11 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
| //! size param for kernels with non-contiguous batch | //! size param for kernels with non-contiguous batch | ||||
| @@ -421,10 +425,14 @@ protected: | |||||
| virtual ncb_kern_t dispatch_kern( | virtual ncb_kern_t dispatch_kern( | ||||
| ConvolutionBackwardDataImpl* opr, | ConvolutionBackwardDataImpl* opr, | ||||
| const NCBKernSizeParam& param) const = 0; | const NCBKernSizeParam& param) const = 0; | ||||
| bool usable_attribute( | |||||
| ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param, | |||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) const { | |||||
| return contain_attribute(attr) && usable(opr, param); | |||||
| bool usable_attribute(ConvolutionBackwardDataImpl* opr, | |||||
| const NCBKernSizeParam& param, | |||||
| const AlgoAttribute& positive_attr = | |||||
| AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = | |||||
| AlgoAttribute::DEFAULT) const { | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && usable(opr, param); | |||||
| } | } | ||||
| virtual bool is_preferred(const NCBKernSizeParam&) const { | virtual bool is_preferred(const NCBKernSizeParam&) const { | ||||
| return false; | return false; | ||||
| @@ -449,7 +457,8 @@ protected: | |||||
| //! default impl calls ncb_1g_get_algorithm_heuristic() | //! default impl calls ncb_1g_get_algorithm_heuristic() | ||||
| virtual Algorithm* get_algorithm_heuristic_with_ncb( | virtual Algorithm* get_algorithm_heuristic_with_ncb( | ||||
| const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| //! get kernel pointer for float32 non-contiguous batch 1-group kernel | //! get kernel pointer for float32 non-contiguous batch 1-group kernel | ||||
| virtual ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo, | virtual ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo, | ||||
| @@ -467,7 +476,8 @@ protected: | |||||
| */ | */ | ||||
| virtual Algorithm* ncb_1g_get_algorithm_heuristic( | virtual Algorithm* ncb_1g_get_algorithm_heuristic( | ||||
| const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| static bool is_matrix_mul_preferred(const NCBKernSizeParam& param); | static bool is_matrix_mul_preferred(const NCBKernSizeParam& param); | ||||
| /** | /** | ||||
| @@ -131,20 +131,24 @@ MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( | |||||
| MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto kern_size_param = make_kern_size_param(A, B, C); | auto kern_size_param = make_kern_size_param(A, B, C); | ||||
| if (auto algo = static_cast<AlgoBase*>( | if (auto algo = static_cast<AlgoBase*>( | ||||
| get_algorithm_from_desc(execution_policy().algo))) { | get_algorithm_from_desc(execution_policy().algo))) { | ||||
| megdnn_assert(algo->get_workspace(kern_size_param) < | megdnn_assert(algo->get_workspace(kern_size_param) < | ||||
| workspace_limit_in_bytes); | workspace_limit_in_bytes); | ||||
| auto cur = megdnn::get_algo_with_attribute<MatrixMulImpl>(algo, attr); | |||||
| auto cur = megdnn::get_algo_match_attribute<MatrixMulImpl>( | |||||
| algo, positive_attr, negative_attr); | |||||
| if (cur) | if (cur) | ||||
| return cur; | return cur; | ||||
| megdnn_throw(ssprintf( | |||||
| "require algorithm with attribute%s, but given algorithm with " | |||||
| "attribute%s", | |||||
| Algorithm::attribute_str(attr).c_str(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str())); | |||||
| megdnn_throw( | |||||
| ssprintf("require algorithm without attribute(%s) with " | |||||
| "attribute(%s), but given algorithm with " | |||||
| "attribute(%s)", | |||||
| Algorithm::attribute_str(negative_attr).c_str(), | |||||
| Algorithm::attribute_str(positive_attr).c_str(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str())); | |||||
| } | } | ||||
| AlgoTypePack algo_type; | AlgoTypePack algo_type; | ||||
| algo_type.data_type = kern_size_param.deduce_algo_data_type(); | algo_type.data_type = kern_size_param.deduce_algo_data_type(); | ||||
| @@ -157,7 +161,7 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | |||||
| static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <= | static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <= | ||||
| workspace_limit_in_bytes) { | workspace_limit_in_bytes) { | ||||
| if (static_cast<AlgoBase*>(algo)->preferred_attribute( | if (static_cast<AlgoBase*>(algo)->preferred_attribute( | ||||
| kern_size_param, attr)) { | |||||
| kern_size_param, positive_attr, negative_attr)) { | |||||
| //! use gemv algo if it's prefered | //! use gemv algo if it's prefered | ||||
| if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | ||||
| return algo; | return algo; | ||||
| @@ -215,9 +219,9 @@ MatrixMulImpl::KernParam MatrixMulImpl::make_kern_param( | |||||
| size_t MatrixMulImpl::get_workspace_in_bytes(const TensorLayout& A, | size_t MatrixMulImpl::get_workspace_in_bytes(const TensorLayout& A, | ||||
| const TensorLayout& B, | const TensorLayout& B, | ||||
| const TensorLayout& C) { | const TensorLayout& C) { | ||||
| if (auto algo = get_algorithm_heuristic(A, B, C, | |||||
| std::numeric_limits<size_t>::max(), | |||||
| AlgoAttribute::DEFAULT)) { | |||||
| if (auto algo = get_algorithm_heuristic( | |||||
| A, B, C, std::numeric_limits<size_t>::max(), | |||||
| AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT)) { | |||||
| auto kern_size_param = make_kern_size_param(A, B, C); | auto kern_size_param = make_kern_size_param(A, B, C); | ||||
| return static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param); | return static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param); | ||||
| } | } | ||||
| @@ -230,6 +234,7 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| if (auto algo = get_algorithm_heuristic(A.layout, B.layout, C.layout, | if (auto algo = get_algorithm_heuristic(A.layout, B.layout, C.layout, | ||||
| std::numeric_limits<size_t>::max(), | std::numeric_limits<size_t>::max(), | ||||
| AlgoAttribute::DEFAULT, | |||||
| AlgoAttribute::DEFAULT)) { | AlgoAttribute::DEFAULT)) { | ||||
| auto kern_param = make_kern_param(A, B, C, workspace); | auto kern_param = make_kern_param(A, B, C, workspace); | ||||
| auto kern = static_cast<AlgoBase*>(algo)->get_kern(kern_param); | auto kern = static_cast<AlgoBase*>(algo)->get_kern(kern_param); | ||||
| @@ -225,8 +225,11 @@ public: | |||||
| }; | }; | ||||
| bool preferred_attribute( | bool preferred_attribute( | ||||
| const KernSizeParam& param, | const KernSizeParam& param, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE) { | |||||
| return contain_attribute(attr) && preferred(param); | |||||
| const AlgoAttribute& positive_attr = | |||||
| AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && preferred(param); | |||||
| }; | }; | ||||
| virtual MatmulDescription matmul_description() const = 0; | virtual MatmulDescription matmul_description() const = 0; | ||||
| @@ -267,12 +270,10 @@ protected: | |||||
| const TensorLayout& B, | const TensorLayout& B, | ||||
| const TensorLayout& C) override; | const TensorLayout& C) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| }; | }; | ||||
| } // namespace fallback | } // namespace fallback | ||||
| @@ -125,14 +125,11 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& /* bias */, const TensorLayout& /* z */, | const TensorLayout& /* bias */, const TensorLayout& /* z */, | ||||
| const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */ | const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */ | ||||
| , | , | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = static_cast<HandleImpl*>(handle()) | auto algo = static_cast<HandleImpl*>(handle()) | ||||
| ->default_batch_conv_bias_fwd_algo(); | ->default_batch_conv_bias_fwd_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -31,13 +31,12 @@ public: | |||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& bias, | |||||
| const TensorLayout& z, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | ||||
| @@ -76,7 +76,8 @@ BatchedMatrixMulForward::Algorithm* | |||||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | ||||
| const AlgoAttribute& /*attr*/) { | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) { | |||||
| return static_cast<HandleImpl*>(handle()) | return static_cast<HandleImpl*>(handle()) | ||||
| ->default_batched_matmul_fwd_algo(); | ->default_batched_matmul_fwd_algo(); | ||||
| } | } | ||||
| @@ -28,11 +28,11 @@ public: | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, | |||||
| size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*attr*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override; | |||||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | ||||
| @@ -246,14 +246,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
| const TensorLayout& /* bias */, const TensorLayout& /* z */, | const TensorLayout& /* bias */, const TensorLayout& /* z */, | ||||
| const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = | auto algo = | ||||
| static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo(); | static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -31,13 +31,12 @@ public: | |||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& bias, | |||||
| const TensorLayout& z, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& bias, const TensorLayout& z, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| size_t get_workspace_in_bytes( | size_t get_workspace_in_bytes( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| @@ -272,14 +272,11 @@ ConvolutionForwardImpl:: get_all_algorithms(const TensorLayout &, | |||||
| ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = | auto algo = | ||||
| static_cast<HandleImpl*>(handle())->default_conv_fwd_algo(); | static_cast<HandleImpl*>(handle())->default_conv_fwd_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -302,14 +299,11 @@ ConvolutionBackwardData::Algorithm* | |||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* filter */, const TensorLayout& /* diff */, | const TensorLayout& /* filter */, const TensorLayout& /* diff */, | ||||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = | auto algo = | ||||
| static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo(); | static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -333,14 +327,11 @@ ConvolutionBackwardFilter::Algorithm* | |||||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = | auto algo = | ||||
| static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo(); | static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -25,11 +25,11 @@ class ConvolutionForwardImpl: public ConvolutionForward { | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | ||||
| const TensorLayout &filter, | const TensorLayout &filter, | ||||
| const TensorLayout &dst) override; | const TensorLayout &dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&, | const TensorLayout&, | ||||
| const PreprocessedFilter*) override { | const PreprocessedFilter*) override { | ||||
| @@ -67,11 +67,11 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter, | ||||
| const TensorLayout &diff, | const TensorLayout &diff, | ||||
| const TensorLayout &grad) override; | const TensorLayout &grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&) override; | const TensorLayout&) override; | ||||
| @@ -90,11 +90,11 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||||
| std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src, | ||||
| const TensorLayout &diff, | const TensorLayout &diff, | ||||
| const TensorLayout &grad) override; | const TensorLayout &grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&) override; | const TensorLayout&) override; | ||||
| @@ -120,13 +120,10 @@ Convolution3DForward::Algorithm* | |||||
| Convolution3DForwardImpl::get_algorithm_heuristic( | Convolution3DForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | const TensorLayout& /* src */, const TensorLayout& /* filter */, | ||||
| const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo(); | auto algo = static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -150,14 +147,11 @@ Convolution3DBackwardData::Algorithm* | |||||
| Convolution3DBackwardDataImpl::get_algorithm_heuristic( | Convolution3DBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* filter */, const TensorLayout& /* diff */, | const TensorLayout& /* filter */, const TensorLayout& /* diff */, | ||||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = | auto algo = | ||||
| static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo(); | static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -183,14 +177,11 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */ | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */ | ||||
| , | , | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = static_cast<HandleImpl*>(handle()) | auto algo = static_cast<HandleImpl*>(handle()) | ||||
| ->default_conv3d_bwd_filter_algo(); | ->default_conv3d_bwd_filter_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -22,11 +22,11 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&) override { | const TensorLayout&) override { | ||||
| return 0; | return 0; | ||||
| @@ -44,11 +44,11 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&) override { | const TensorLayout&) override { | ||||
| return 0; | return 0; | ||||
| @@ -66,11 +66,11 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
| const TensorLayout&) override { | const TensorLayout&) override { | ||||
| return 0; | return 0; | ||||
| @@ -26,13 +26,13 @@ public: | |||||
| return std::vector<Algorithm*>(); | return std::vector<Algorithm*>(); | ||||
| }; | }; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /* src */, | |||||
| const TensorLayout& /* filter */, | |||||
| const TensorLayout& /* offset */, | |||||
| const TensorLayout& /* mask */, | |||||
| const TensorLayout& /* dst */, | |||||
| size_t /* workspace_limit_in_bytes */, | |||||
| const AlgoAttribute& /*attr*/) override { | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /* src */, const TensorLayout& /* filter */, | |||||
| const TensorLayout& /* offset */, const TensorLayout& /* mask */, | |||||
| const TensorLayout& /* dst */, | |||||
| size_t /* workspace_limit_in_bytes */, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override { | |||||
| return nullptr; | return nullptr; | ||||
| }; | }; | ||||
| @@ -68,13 +68,13 @@ public: | |||||
| return std::vector<Algorithm*>(); | return std::vector<Algorithm*>(); | ||||
| }; | }; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /* im */, | |||||
| const TensorLayout& /* offset */, | |||||
| const TensorLayout& /* mask */, | |||||
| const TensorLayout& /* out_grad */, | |||||
| const TensorLayout& /* filter_grad */, | |||||
| size_t /* workspace_limit_in_bytes */, | |||||
| const AlgoAttribute& /*attr*/) override { | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /* im */, const TensorLayout& /* offset */, | |||||
| const TensorLayout& /* mask */, const TensorLayout& /* out_grad */, | |||||
| const TensorLayout& /* filter_grad */, | |||||
| size_t /* workspace_limit_in_bytes */, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override { | |||||
| return nullptr; | return nullptr; | ||||
| }; | }; | ||||
| @@ -112,16 +112,16 @@ public: | |||||
| return std::vector<Algorithm*>(); | return std::vector<Algorithm*>(); | ||||
| }; | }; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /* im */, | |||||
| const TensorLayout& /* filter */, | |||||
| const TensorLayout& /* offset */, | |||||
| const TensorLayout& /* mask */, | |||||
| const TensorLayout& /* out_grad */, | |||||
| const TensorLayout& /* im_grad */, | |||||
| const TensorLayout& /* offset_grad */, | |||||
| const TensorLayout& /* mask_grad */, | |||||
| size_t /* workspace_limit_in_bytes */, | |||||
| const AlgoAttribute& /*attr*/) override { | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /* im */, const TensorLayout& /* filter */, | |||||
| const TensorLayout& /* offset */, const TensorLayout& /* mask */, | |||||
| const TensorLayout& /* out_grad */, | |||||
| const TensorLayout& /* im_grad */, | |||||
| const TensorLayout& /* offset_grad */, | |||||
| const TensorLayout& /* mask_grad */, | |||||
| size_t /* workspace_limit_in_bytes */, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override { | |||||
| return nullptr; | return nullptr; | ||||
| }; | }; | ||||
| @@ -162,14 +162,11 @@ LocalShareForwardImpl::get_all_algorithms(const TensorLayout&, | |||||
| LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = | auto algo = | ||||
| static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo(); | static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -194,14 +191,11 @@ LocalShareBackwardData::Algorithm* | |||||
| LocalShareBackwardDataImpl::get_algorithm_heuristic( | LocalShareBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* filter */, const TensorLayout& /* diff */, | const TensorLayout& /* filter */, const TensorLayout& /* diff */, | ||||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = static_cast<HandleImpl*>(handle()) | auto algo = static_cast<HandleImpl*>(handle()) | ||||
| ->default_local_share_bwd_data_algo(); | ->default_local_share_bwd_data_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -226,14 +220,11 @@ LocalShareBackwardFilter::Algorithm* | |||||
| LocalShareBackwardFilterImpl::get_algorithm_heuristic( | LocalShareBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /* src */, const TensorLayout& /* diff */, | const TensorLayout& /* src */, const TensorLayout& /* diff */, | ||||
| const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto algo = static_cast<HandleImpl*>(handle()) | auto algo = static_cast<HandleImpl*>(handle()) | ||||
| ->default_local_share_bwd_filter_algo(); | ->default_local_share_bwd_filter_algo(); | ||||
| megdnn_assert(algo->contain_attribute(attr), | |||||
| "require algorithm with attribute%s, but heuristic " | |||||
| "algorithm(%s) with attribute%s ", | |||||
| Algorithm::attribute_str(attr).c_str(), algo->name(), | |||||
| Algorithm::attribute_str(algo->attribute()).c_str()); | |||||
| algo->check_attribute(positive_attr, negative_attr); | |||||
| return algo; | return algo; | ||||
| } | } | ||||
| @@ -30,11 +30,11 @@ public: | |||||
| const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | ||||
| const TensorLayout& /*dst*/) override; | const TensorLayout& /*dst*/) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/, | |||||
| const TensorLayout& /*filter*/, | |||||
| const TensorLayout& /*dst*/, | |||||
| size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*attr*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /*src*/, const TensorLayout& /*filter*/, | |||||
| const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override; | |||||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | ||||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
| @@ -55,11 +55,11 @@ public: | |||||
| const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | ||||
| const TensorLayout& /*grad*/) override; | const TensorLayout& /*grad*/) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*filter*/, | |||||
| const TensorLayout& /*diff*/, | |||||
| const TensorLayout& /*grad*/, | |||||
| size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*attr*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, | |||||
| const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override; | |||||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | ||||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
| @@ -80,11 +80,11 @@ public: | |||||
| const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | ||||
| const TensorLayout& /*grad*/) override; | const TensorLayout& /*grad*/) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/, | |||||
| const TensorLayout& /*diff*/, | |||||
| const TensorLayout& /*grad*/, | |||||
| size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*attr*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /*src*/, const TensorLayout& /*diff*/, | |||||
| const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override; | |||||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | ||||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
| @@ -91,7 +91,8 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||||
| MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | ||||
| const AlgoAttribute& /*attr*/) { | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) { | |||||
| return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | ||||
| } | } | ||||
| @@ -29,11 +29,11 @@ public: | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, | |||||
| size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*attr*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override; | |||||
| Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; | ||||
| @@ -72,9 +72,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -32,21 +32,17 @@ BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| BatchedMatrixMulForwardImpl::Algorithm* | BatchedMatrixMulForwardImpl::Algorithm* | ||||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | AlgoBase::SizeArgs args{this, A, B, C}; | ||||
| if (sm_algo_pack.blas.is_available_attribute(args, attr, | |||||
| workspace_limit_in_bytes)) { | |||||
| if (sm_algo_pack.blas.is_available_attribute( | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.blas; | return &sm_algo_pack.blas; | ||||
| } | } | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<BatchedMatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "batched matrix mul forward", positive_attr, negative_attr); | |||||
| } | } | ||||
| size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( | size_t BatchedMatrixMulForwardImpl::get_workspace_in_bytes( | ||||
| @@ -36,11 +36,11 @@ private: | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, | |||||
| size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*attr*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override; | |||||
| const char* get_algorithm_set_name() const override { | const char* get_algorithm_set_name() const override { | ||||
| return "ROCM BATCHED MATMUL"; | return "ROCM BATCHED MATMUL"; | ||||
| @@ -76,9 +76,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| @@ -73,9 +73,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| @@ -75,9 +75,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| @@ -29,40 +29,43 @@ using namespace rocm; | |||||
| /* ============== ConvolutionForwardImpl ============== */ | /* ============== ConvolutionForwardImpl ============== */ | ||||
| ConvolutionForwardImpl::Algorithm* | ConvolutionForwardImpl::Algorithm* | ||||
| ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) { | |||||
| ConvolutionForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = check_layout_fwd(src, filter, dst); | auto fm = check_layout_fwd(src, filter, dst); | ||||
| return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, | return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, | ||||
| attr); | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| ConvolutionForwardImpl::Algorithm* | ConvolutionForwardImpl::Algorithm* | ||||
| ConvolutionForwardImpl::get_algorithm_heuristic( | ConvolutionForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const CanonizedFilterMeta& filter, | const TensorLayout& src, const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, src, filter, dst); | AlgoBase::SizeArgs args(this, src, filter, dst); | ||||
| //! MIOpen auto-tuning need to run with actual tensors, so we cannot get | //! MIOpen auto-tuning need to run with actual tensors, so we cannot get | ||||
| //! best algorithm here. | //! best algorithm here. | ||||
| if (is_miopen_supported(args)) { | if (is_miopen_supported(args)) { | ||||
| auto algo = megdnn::get_algo_with_attribute<ConvolutionForwardImpl>( | |||||
| sm_algo_pack.miopen_algos[0], attr); | |||||
| auto algo = megdnn::get_algo_match_attribute<ConvolutionForwardImpl>( | |||||
| sm_algo_pack.miopen_algos[0], positive_attr, negative_attr); | |||||
| if (algo) | if (algo) | ||||
| return algo; | return algo; | ||||
| } | } | ||||
| if (args.filter_meta.group > 1) { | if (args.filter_meta.group > 1) { | ||||
| if (sm_algo_pack.chanwise.is_available_attribute( | if (sm_algo_pack.chanwise.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
| } | } | ||||
| } | } | ||||
| auto prefer_1x1 = [&args, attr, workspace_limit_in_bytes]() { | |||||
| auto prefer_1x1 = [&args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes]() { | |||||
| const size_t MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO = 4; | const size_t MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO = 4; | ||||
| size_t batch_size = args.src_layout->shape[0]; | size_t batch_size = args.src_layout->shape[0]; | ||||
| @@ -70,14 +73,15 @@ ConvolutionForwardImpl::get_algorithm_heuristic( | |||||
| return false; | return false; | ||||
| } | } | ||||
| return sm_algo_pack.a1x1.is_available_attribute( | return sm_algo_pack.a1x1.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes); | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes); | |||||
| }; | }; | ||||
| if (prefer_1x1()) { | if (prefer_1x1()) { | ||||
| return &sm_algo_pack.a1x1; | return &sm_algo_pack.a1x1; | ||||
| } | } | ||||
| auto prefer_1x1_large_batch = [&args, attr, workspace_limit_in_bytes]() { | |||||
| auto prefer_1x1_large_batch = [&args, positive_attr, negative_attr, | |||||
| workspace_limit_in_bytes]() { | |||||
| const size_t MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO = 32; | const size_t MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO = 32; | ||||
| size_t batch_size = args.src_layout->shape[0]; | size_t batch_size = args.src_layout->shape[0]; | ||||
| @@ -85,22 +89,16 @@ ConvolutionForwardImpl::get_algorithm_heuristic( | |||||
| return false; | return false; | ||||
| } | } | ||||
| return sm_algo_pack.batched_matrix_mul.is_available_attribute( | return sm_algo_pack.batched_matrix_mul.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes); | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes); | |||||
| }; | }; | ||||
| if (prefer_1x1_large_batch()) { | if (prefer_1x1_large_batch()) { | ||||
| return &sm_algo_pack.batched_matrix_mul; | return &sm_algo_pack.batched_matrix_mul; | ||||
| } | } | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<ConvolutionForwardImpl>( | |||||
| sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | |||||
| "rocm conv fwd", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvolutionForwardImpl>( | |||||
| sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | |||||
| "rocm conv fwd"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvolutionForwardImpl>( | |||||
| sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | |||||
| "rocm conv fwd", positive_attr, negative_attr); | |||||
| } | } | ||||
| std::vector<ConvolutionForwardImpl::Algorithm*> | std::vector<ConvolutionForwardImpl::Algorithm*> | ||||
| @@ -156,41 +154,39 @@ ConvolutionBackwardDataImpl::Algorithm* | |||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = check_layout_fwd(grad, filter, diff); | auto fm = check_layout_fwd(grad, filter, diff); | ||||
| return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | ||||
| attr); | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| ConvolutionBackwardDataImpl::Algorithm* | ConvolutionBackwardDataImpl::Algorithm* | ||||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ConvolutionBackwardDataImpl::get_algorithm_heuristic( | ||||
| const CanonizedFilterMeta& filter, const TensorLayout& diff, | const CanonizedFilterMeta& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, filter, diff, grad); | AlgoBase::SizeArgs args(this, filter, diff, grad); | ||||
| if (is_miopen_supported(args.as_fwd_args())) { | if (is_miopen_supported(args.as_fwd_args())) { | ||||
| auto algo = megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.miopen_algos[0], attr); | |||||
| auto algo = | |||||
| megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.miopen_algos[0], positive_attr, | |||||
| negative_attr); | |||||
| if (algo) | if (algo) | ||||
| return algo; | return algo; | ||||
| } | } | ||||
| if (args.filter_meta.group > 1 && | if (args.filter_meta.group > 1 && | ||||
| sm_algo_pack.chanwise.is_available_attribute( | sm_algo_pack.chanwise.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
| } | } | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | |||||
| "rocm conv bwd_data", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | |||||
| "rocm conv bwd_data"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvolutionBackwardDataImpl>( | |||||
| sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | |||||
| "rocm conv bwd_data", positive_attr, negative_attr); | |||||
| } | } | ||||
| size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | ||||
| @@ -229,43 +225,40 @@ ConvolutionBackwardFilterImpl::Algorithm* | |||||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| auto fm = check_layout_fwd(src, grad, diff); | auto fm = check_layout_fwd(src, grad, diff); | ||||
| return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | ||||
| attr); | |||||
| positive_attr, negative_attr); | |||||
| } | } | ||||
| ConvolutionBackwardFilterImpl::Algorithm* | ConvolutionBackwardFilterImpl::Algorithm* | ||||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args(this, src, diff, grad); | AlgoBase::SizeArgs args(this, src, diff, grad); | ||||
| if (is_miopen_supported(args.as_fwd_args())) { | if (is_miopen_supported(args.as_fwd_args())) { | ||||
| auto algo = | auto algo = | ||||
| megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.miopen_algos[0], attr); | |||||
| megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.miopen_algos[0], positive_attr, | |||||
| negative_attr); | |||||
| if (algo) | if (algo) | ||||
| return algo; | return algo; | ||||
| } | } | ||||
| if (args.grad_filter_meta.group > 1 && | if (args.grad_filter_meta.group > 1 && | ||||
| sm_algo_pack.chanwise.is_available_attribute( | sm_algo_pack.chanwise.is_available_attribute( | ||||
| args, attr, workspace_limit_in_bytes)) { | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| // prefer special chanwise impl | // prefer special chanwise impl | ||||
| return &sm_algo_pack.chanwise; | return &sm_algo_pack.chanwise; | ||||
| } | } | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | |||||
| "rocm conv bwd_filter", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | |||||
| "rocm conv bwd_filter"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<ConvolutionBackwardFilterImpl>( | |||||
| sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, | |||||
| "rocm conv bwd_filter", positive_attr, negative_attr); | |||||
| } | } | ||||
| size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | ||||
| @@ -26,9 +26,11 @@ public: | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const TensorLayout& src, const CanonizedFilterMeta& filter, | const TensorLayout& src, const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | const TensorLayout& dst, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(src, filter, dst, | return get_algorithm_heuristic(src, filter, dst, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
| @@ -72,16 +74,17 @@ private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& filter, | const TensorLayout& src, const TensorLayout& filter, | ||||
| const TensorLayout& dst) override; | const TensorLayout& dst) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& filter, | |||||
| const TensorLayout& dst, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& filter, | |||||
| const TensorLayout& dst, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
| const CanonizedFilterMeta& filter, | const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& dst, | const TensorLayout& dst, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -94,9 +97,11 @@ public: | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const CanonizedFilterMeta& filter, const TensorLayout& diff, | const CanonizedFilterMeta& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | const TensorLayout& grad, size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr) { | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(filter, diff, grad, | return get_algorithm_heuristic(filter, diff, grad, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | size_t get_workspace_in_bytes(const TensorLayout& filter, | ||||
| @@ -118,16 +123,17 @@ private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& filter, const TensorLayout& diff, | const TensorLayout& filter, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& filter, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& filter, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const TensorLayout& grad, | const TensorLayout& grad, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -137,13 +143,14 @@ public: | |||||
| using ConvolutionBackwardFilter::ConvolutionBackwardFilter; | using ConvolutionBackwardFilter::ConvolutionBackwardFilter; | ||||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, | ||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | ||||
| AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const CanonizedFilterMeta& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) { | |||||
| AlgorithmInfo get_algorithm_info_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| return get_algorithm_heuristic(src, diff, grad, | return get_algorithm_heuristic(src, diff, grad, | ||||
| workspace_limit_in_bytes, attr) | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | size_t get_workspace_in_bytes(const TensorLayout& src, | ||||
| @@ -165,16 +172,17 @@ private: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& src, const TensorLayout& diff, | const TensorLayout& src, const TensorLayout& diff, | ||||
| const TensorLayout& grad) override; | const TensorLayout& grad) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& src, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | Algorithm* get_algorithm_heuristic(const TensorLayout& src, | ||||
| const TensorLayout& diff, | const TensorLayout& diff, | ||||
| const CanonizedFilterMeta& grad, | const CanonizedFilterMeta& grad, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr); | |||||
| static AlgoPack sm_algo_pack; | static AlgoPack sm_algo_pack; | ||||
| }; | }; | ||||
| @@ -72,9 +72,12 @@ public: | |||||
| } | } | ||||
| bool is_available_attribute( | bool is_available_attribute( | ||||
| const SizeArgs& args, | const SizeArgs& args, | ||||
| const AlgoAttribute& attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, | |||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, | |||||
| size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
| return contain_attribute(attr) && is_available_wk(args, limit); | |||||
| return contain_attribute_all(positive_attr) && | |||||
| !contain_attribute_any(negative_attr) && | |||||
| is_available_wk(args, limit); | |||||
| } | } | ||||
| AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
| const Workspace& workspace) { | const Workspace& workspace) { | ||||
| @@ -29,21 +29,16 @@ MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, | |||||
| MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | ||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& attr) { | |||||
| size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr) { | |||||
| AlgoBase::SizeArgs args{this, A, B, C}; | AlgoBase::SizeArgs args{this, A, B, C}; | ||||
| if (sm_algo_pack.blas.is_available_attribute(args, attr, | |||||
| workspace_limit_in_bytes)) { | |||||
| if (sm_algo_pack.blas.is_available_attribute( | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.blas; | return &sm_algo_pack.blas; | ||||
| } | } | ||||
| if (attr != AlgoAttribute::DEFAULT) { | |||||
| return megdnn::get_algo_with_attribute<MatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "matrix mul forward", attr); | |||||
| } else { | |||||
| return megdnn::get_usable_algo<MatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "matrix mul forward"); | |||||
| } | |||||
| return megdnn::get_algo_match_attribute<MatrixMulForwardImpl>( | |||||
| sm_algo_pack.all_algos, args, workspace_limit_in_bytes, | |||||
| "matrix mul forward", positive_attr, negative_attr); | |||||
| } | } | ||||
| size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | ||||
| @@ -36,11 +36,11 @@ private: | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override; | const TensorLayout& /*C*/) override; | ||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, | |||||
| size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*attr*/) override; | |||||
| Algorithm* get_algorithm_heuristic( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||||
| const AlgoAttribute& /*positive_attr*/, | |||||
| const AlgoAttribute& /*negative_attr*/) override; | |||||
| const char* get_algorithm_set_name() const override { | const char* get_algorithm_set_name() const override { | ||||
| return "ROCM MATMUL"; | return "ROCM MATMUL"; | ||||
| @@ -278,27 +278,21 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| AlgoAttribute extract_algo_attribute_from_execution_strategy( | |||||
| //! return pair<positive_attr, negative_attr> | |||||
| std::pair<AlgoAttribute, AlgoAttribute> | |||||
| extract_algo_attribute_from_execution_strategy( | |||||
| const ExecutionStrategy& strategy) { | const ExecutionStrategy& strategy) { | ||||
| AlgoAttribute ret = AlgoAttribute::DEFAULT; | |||||
| std::pair<AlgoAttribute, AlgoAttribute> ret = | |||||
| std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||||
| if (strategy & ExecutionStrategy::REPRODUCIBLE) { | if (strategy & ExecutionStrategy::REPRODUCIBLE) { | ||||
| ret |= AlgoAttribute::REPRODUCIBLE; | |||||
| ret.first |= AlgoAttribute::REPRODUCIBLE; | |||||
| } | } | ||||
| return ret; | |||||
| } | |||||
| //! Test whether the algo attribute of a algo match the require | |||||
| //! algo_strategy | |||||
| static bool algo_attribute_match_strategy(AlgoAttribute attribute, | |||||
| ExecutionStrategy selected_strategy) { | |||||
| bool ret = true; | |||||
| if (selected_strategy & ExecutionStrategy::OPTMIZED) { | |||||
| ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute)); | |||||
| } else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) { | |||||
| ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute); | |||||
| if (strategy & ExecutionStrategy::OPTMIZED) { | |||||
| ret.second |= AlgoAttribute::NAIVE; | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -311,7 +305,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
| return; | return; | ||||
| AlgoChooserProfileCache::Result prof_rst; | AlgoChooserProfileCache::Result prof_rst; | ||||
| auto target_attribute = | |||||
| auto target_attr = | |||||
| extract_algo_attribute_from_execution_strategy(selected_strategy); | extract_algo_attribute_from_execution_strategy(selected_strategy); | ||||
| std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | ||||
| double cur_timeout = 0; | double cur_timeout = 0; | ||||
| @@ -332,14 +326,16 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo); | auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo); | ||||
| if (!algo_attribute_match_strategy(palgo->attribute(), | |||||
| selected_strategy)) { | |||||
| if (!(palgo->contain_attribute_all(target_attr.first) && | |||||
| !palgo->contain_attribute_any(target_attr.second))) { | |||||
| mgb_log_debug( | mgb_log_debug( | ||||
| "skip algo %s with attribute%s, which is not match the " | |||||
| "profile strategy required attribute%s.", | |||||
| "skip algo %s with attribute(%s), which is not match the " | |||||
| "profile strategy required contain attribute(%s) and not " | |||||
| "contain attribute(%s).", | |||||
| algo.name.c_str(), | algo.name.c_str(), | ||||
| Algorithm::attribute_str(palgo->attribute()).c_str(), | Algorithm::attribute_str(palgo->attribute()).c_str(), | ||||
| Algorithm::attribute_str(target_attribute).c_str()); | |||||
| Algorithm::attribute_str(target_attr.first).c_str(), | |||||
| Algorithm::attribute_str(target_attr.second).c_str()); | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -370,10 +366,12 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
| rst.workspace, rst.time); | rst.workspace, rst.time); | ||||
| prof_rst.push_back(rst); | prof_rst.push_back(rst); | ||||
| } | } | ||||
| std::string msg = | |||||
| ssprintf("no usable %s algorithm %s with attribute(%s)", | |||||
| ctx.mgb_opr()->dyn_typeinfo()->name, layouts_str.c_str(), | |||||
| Algorithm::attribute_str(target_attribute).c_str()); | |||||
| std::string msg = ssprintf( | |||||
| "no usable %s algorithm %s with attribute(%s) and without " | |||||
| "attribute(%s)", | |||||
| ctx.mgb_opr()->dyn_typeinfo()->name, layouts_str.c_str(), | |||||
| Algorithm::attribute_str(target_attr.first).c_str(), | |||||
| Algorithm::attribute_str(target_attr.second).c_str()); | |||||
| mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); | mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); | ||||
| FixedTensorLayouts origin_layouts = ctx.layouts(); | FixedTensorLayouts origin_layouts = ctx.layouts(); | ||||
| @@ -460,9 +458,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, | |||||
| Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo); | Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo); | ||||
| mgb_assert(palgo, "Unknown algo description"); | mgb_assert(palgo, "Unknown algo description"); | ||||
| ret.append("): algo=" + std::string(palgo->name())); | ret.append("): algo=" + std::string(palgo->name())); | ||||
| ret.append(ssprintf(" workspace=%.2fMiB attirbute=%d", | |||||
| ret.append(ssprintf(" workspace=%.2fMiB attirbute(%s)", | |||||
| workspace / (1024 * 1024.0), | workspace / (1024 * 1024.0), | ||||
| static_cast<uint32_t>(palgo->attribute()))); | |||||
| Algorithm::attribute_str(palgo->attribute()).c_str())); | |||||
| mgb_log_debug("%s", ret.c_str()); | mgb_log_debug("%s", ret.c_str()); | ||||
| megdnn_opr->execution_policy() = policy; | megdnn_opr->execution_policy() = policy; | ||||
| @@ -602,13 +600,14 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||||
| } | } | ||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
| auto attr = | |||||
| extract_algo_attribute_from_execution_strategy(selected_strategy); | |||||
| ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
| policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||||
| args..., workspace_limit, | |||||
| extract_algo_attribute_from_execution_strategy( | |||||
| selected_strategy)), | |||||
| m_layouts) | |||||
| .desc; | |||||
| policy.algo = | |||||
| APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||||
| args..., workspace_limit, attr.first, attr.second), | |||||
| m_layouts) | |||||
| .desc; | |||||
| Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); | ||||
| mgb_assert(algo, "Unknown algo description"); | mgb_assert(algo, "Unknown algo description"); | ||||
| @@ -666,13 +665,14 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
| } else { | } else { | ||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
| policy.algo = | |||||
| APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||||
| args..., workspace_limit, | |||||
| extract_algo_attribute_from_execution_strategy( | |||||
| selected_strategy)), | |||||
| m_layouts) | |||||
| .desc; | |||||
| auto attr = extract_algo_attribute_from_execution_strategy( | |||||
| selected_strategy); | |||||
| policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||||
| args..., workspace_limit, attr.first, | |||||
| attr.second), | |||||
| m_layouts) | |||||
| .desc; | |||||
| } | } | ||||
| mgb_assert(policy.algo.valid(), | mgb_assert(policy.algo.valid(), | ||||
| "No algo found from cache or heuristic, maybe some error " | "No algo found from cache or heuristic, maybe some error " | ||||
| @@ -2189,7 +2189,7 @@ TEST(TestOprDNN, HeuristicReproducible) { | |||||
| megdnn_opr->get_algorithm_from_desc(algo); | megdnn_opr->get_algorithm_from_desc(algo); | ||||
| mgb_assert(palgo, "Unknown algo description"); | mgb_assert(palgo, "Unknown algo description"); | ||||
| if (strategy == S(S::HEURISTIC | S::REPRODUCIBLE)) { | if (strategy == S(S::HEURISTIC | S::REPRODUCIBLE)) { | ||||
| EXPECT_TRUE(palgo->contain_attribute( | |||||
| EXPECT_TRUE(palgo->contain_attribute_all( | |||||
| megdnn::AlgoAttribute::REPRODUCIBLE)); | megdnn::AlgoAttribute::REPRODUCIBLE)); | ||||
| } | } | ||||
| algo_name0 = palgo->name(); | algo_name0 = palgo->name(); | ||||
| @@ -2371,21 +2371,23 @@ public: | |||||
| std::vector<AlgorithmInfo>(const TensorLayout& p0, | std::vector<AlgorithmInfo>(const TensorLayout& p0, | ||||
| const TensorLayout& p1, | const TensorLayout& p1, | ||||
| const TensorLayout& p2)); | const TensorLayout& p2)); | ||||
| MOCK_METHOD5(get_algorithm_info_heuristic, | |||||
| MOCK_METHOD6(get_algorithm_info_heuristic, | |||||
| AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, | AlgorithmInfo(const TensorLayout& p0, const TensorLayout& p1, | ||||
| const TensorLayout& p2, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& attr)); | |||||
| const TensorLayout& p2, | |||||
| size_t workspace_limit_in_bytes, | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr)); | |||||
| MOCK_METHOD3(get_all_algorithms, | MOCK_METHOD3(get_all_algorithms, | ||||
| std::vector<Algorithm*>(const TensorLayout& p0, | std::vector<Algorithm*>(const TensorLayout& p0, | ||||
| const TensorLayout& p1, | const TensorLayout& p1, | ||||
| const TensorLayout& p2)); | const TensorLayout& p2)); | ||||
| MOCK_METHOD5(get_algorithm_heuristic, | |||||
| MOCK_METHOD6(get_algorithm_heuristic, | |||||
| Algorithm*(const TensorLayout& p0, const TensorLayout& p1, | Algorithm*(const TensorLayout& p0, const TensorLayout& p1, | ||||
| const TensorLayout& p2, | const TensorLayout& p2, | ||||
| size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
| const AlgoAttribute& attr)); | |||||
| const AlgoAttribute& positive_attr, | |||||
| const AlgoAttribute& negative_attr)); | |||||
| MOCK_METHOD1(get_algorithm_from_desc, | MOCK_METHOD1(get_algorithm_from_desc, | ||||
| Algorithm*(const AlgorithmDesc&)); | Algorithm*(const AlgorithmDesc&)); | ||||
| @@ -2468,7 +2470,7 @@ TEST_F(TestWeightPreprocess, NoPreprocessNeeded) { | |||||
| auto& mock = mock_conv(); | auto& mock = mock_conv(); | ||||
| MockAlgorithm algo; | MockAlgorithm algo; | ||||
| EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _)) | |||||
| EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _, _)) | |||||
| .WillRepeatedly(Return(&algo)); | .WillRepeatedly(Return(&algo)); | ||||
| EXPECT_CALL(mock, get_algorithm_from_desc(_)) | EXPECT_CALL(mock, get_algorithm_from_desc(_)) | ||||
| .WillRepeatedly(Return(&algo)); | .WillRepeatedly(Return(&algo)); | ||||
| @@ -2508,7 +2510,7 @@ TEST_F(TestWeightPreprocess, PreprocessCalledOnlyOnce) { | |||||
| .WillRepeatedly(Return(&algo)); | .WillRepeatedly(Return(&algo)); | ||||
| Expectation algo_call = | Expectation algo_call = | ||||
| EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _)) | |||||
| EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _, _)) | |||||
| .WillOnce(Return(&algo)); | .WillOnce(Return(&algo)); | ||||
| Expectation ws_call = EXPECT_CALL(mock, get_workspace_in_bytes(_, _, _, _)) | Expectation ws_call = EXPECT_CALL(mock, get_workspace_in_bytes(_, _, _, _)) | ||||
| .After(algo_call) | .After(algo_call) | ||||
| @@ -2567,7 +2569,7 @@ TEST_F(TestNoWeightPreprocess, NoPreprocess) { | |||||
| auto& mock = mock_conv(); | auto& mock = mock_conv(); | ||||
| MockAlgorithm algo; | MockAlgorithm algo; | ||||
| EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _)) | |||||
| EXPECT_CALL(mock, get_algorithm_heuristic(_, _, _, _, _, _)) | |||||
| .WillRepeatedly(Return(&algo)); | .WillRepeatedly(Return(&algo)); | ||||
| EXPECT_CALL(mock, get_algorithm_from_desc(_)) | EXPECT_CALL(mock, get_algorithm_from_desc(_)) | ||||
| .WillRepeatedly(Return(&algo)); | .WillRepeatedly(Return(&algo)); | ||||