GitOrigin-RevId: 2409b6ba16
tags/v1.3.0
| @@ -64,9 +64,24 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, | |||||
| } | } | ||||
| } // namespace naive | |||||
| } // namespace megdnn | |||||
| std::vector<BatchedMatrixMulForward::Algorithm*> | |||||
| BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) { | |||||
| return {static_cast<HandleImpl*>(handle()) | |||||
| ->default_batched_matmul_fwd_algo()}; | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| BatchedMatrixMulForward::Algorithm* | |||||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||||
| bool /* reproducible */) { | |||||
| return static_cast<HandleImpl*>(handle()) | |||||
| ->default_batched_matmul_fwd_algo(); | |||||
| } | |||||
| } // namespace naive | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| @@ -25,17 +26,13 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override { | |||||
| return {}; | |||||
| } | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | ||||
| const TensorLayout& /*B*/, | const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/, | const TensorLayout& /*C*/, | ||||
| size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
| bool /* reproducible */) override { | |||||
| return nullptr; | |||||
| } | |||||
| bool /* reproducible */) override; | |||||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||
| @@ -106,6 +106,9 @@ DefaultLocalShareBackwardDataAlgorithm | |||||
| DefaultLocalShareBackwardFilterAlgorithm | DefaultLocalShareBackwardFilterAlgorithm | ||||
| HandleImpl::m_default_local_share_bwd_filter_algo; | HandleImpl::m_default_local_share_bwd_filter_algo; | ||||
| DefaultMatrixMulAlgorithm HandleImpl::m_default_matmul_fwd_algo; | |||||
| DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo; | |||||
| HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, | HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, | ||||
| HandleType type) | HandleType type) | ||||
| : HandleImplHelper(computing_handle, type), | : HandleImplHelper(computing_handle, type), | ||||
| @@ -13,6 +13,7 @@ | |||||
| #include "src/common/handle_impl.h" | #include "src/common/handle_impl.h" | ||||
| #include "src/naive/convolution/algorithms.h" | #include "src/naive/convolution/algorithms.h" | ||||
| #include "src/naive/matrix_mul/algorithms.h" | |||||
| #include "src/naive/local_share/algorithms.h" | #include "src/naive/local_share/algorithms.h" | ||||
| #include "src/naive/convolution3d/algorithms.h" | #include "src/naive/convolution3d/algorithms.h" | ||||
| @@ -46,6 +47,9 @@ class HandleImpl : public HandleImplHelper { | |||||
| static DefaultLocalShareBackwardFilterAlgorithm | static DefaultLocalShareBackwardFilterAlgorithm | ||||
| m_default_local_share_bwd_filter_algo; | m_default_local_share_bwd_filter_algo; | ||||
| static DefaultMatrixMulAlgorithm m_default_matmul_fwd_algo; | |||||
| static DefaultBatchedMatrixMulAlgorithm m_default_batched_matmul_fwd_algo; | |||||
| //! move KernFunc to alloc_kern()->func, destruct func, and call dispatch | //! move KernFunc to alloc_kern()->func, destruct func, and call dispatch | ||||
| template <typename T> | template <typename T> | ||||
| void move_kern_func_to_new_kern_and_dispatch(T& func) { | void move_kern_func_to_new_kern_and_dispatch(T& func) { | ||||
| @@ -109,6 +113,14 @@ public: | |||||
| return &m_default_local_share_bwd_filter_algo; | return &m_default_local_share_bwd_filter_algo; | ||||
| } | } | ||||
| MatrixMulForward::Algorithm* default_matmul_fwd_algo() { | |||||
| return &m_default_matmul_fwd_algo; | |||||
| } | |||||
| BatchedMatrixMulForward::Algorithm* default_batched_matmul_fwd_algo() { | |||||
| return &m_default_batched_matmul_fwd_algo; | |||||
| } | |||||
| Relayout* relayout_opr() override { | Relayout* relayout_opr() override { | ||||
| return get_helper_opr<Relayout, 2>(this); | return get_helper_opr<Relayout, 2>(this); | ||||
| } | } | ||||
| @@ -0,0 +1,35 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/matrix_mul/algorithms.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs/linalg.h" | |||||
| namespace megdnn { | |||||
| namespace naive { | |||||
| class DefaultMatrixMulAlgorithm final | |||||
| : public megdnn::MatrixMulForward::Algorithm { | |||||
| bool is_reproducible() const override { return true; } | |||||
| const char* name() const override { return "DEFAULT"; } | |||||
| uint32_t type() const override { return 0; } | |||||
| }; | |||||
| class DefaultBatchedMatrixMulAlgorithm final | |||||
| : public megdnn::BatchedMatrixMulForward::Algorithm { | |||||
| bool is_reproducible() const override { return true; } | |||||
| const char* name() const override { return "DEFAULT"; } | |||||
| uint32_t type() const override { return 0; } | |||||
| }; | |||||
| } // namespace naive | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -81,6 +81,20 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| } | } | ||||
| std::vector<MatrixMulForward::Algorithm*> | |||||
| MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||||
| const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/) { | |||||
| return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | |||||
| } | |||||
| MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||||
| bool /* reproducible */) { | |||||
| return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | |||||
| } | |||||
| } // namespace naive | } // namespace naive | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| @@ -26,17 +27,13 @@ public: | |||||
| std::vector<Algorithm*> get_all_algorithms( | std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | const TensorLayout& /*A*/, const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/) override { | |||||
| return {}; | |||||
| } | |||||
| const TensorLayout& /*C*/) override; | |||||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | ||||
| const TensorLayout& /*B*/, | const TensorLayout& /*B*/, | ||||
| const TensorLayout& /*C*/, | const TensorLayout& /*C*/, | ||||
| size_t /*workspace_limit_in_bytes*/, | size_t /*workspace_limit_in_bytes*/, | ||||
| bool /* reproducible */) override { | |||||
| return nullptr; | |||||
| } | |||||
| bool /* reproducible */) override; | |||||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | const char* get_algorithm_set_name() const override { return "DEFAULT"; } | ||||