GitOrigin-RevId: 2409b6ba16
tags/v1.3.0
| @@ -64,9 +64,24 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, | |||
| } | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| std::vector<BatchedMatrixMulForward::Algorithm*> | |||
| BatchedMatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) { | |||
| return {static_cast<HandleImpl*>(handle()) | |||
| ->default_batched_matmul_fwd_algo()}; | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| BatchedMatrixMulForward::Algorithm* | |||
| BatchedMatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||
| bool /* reproducible */) { | |||
| return static_cast<HandleImpl*>(handle()) | |||
| ->default_batched_matmul_fwd_algo(); | |||
| } | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| @@ -25,17 +26,13 @@ public: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override { | |||
| return {}; | |||
| } | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /* reproducible */) override { | |||
| return nullptr; | |||
| } | |||
| bool /* reproducible */) override; | |||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||
| @@ -106,6 +106,9 @@ DefaultLocalShareBackwardDataAlgorithm | |||
| DefaultLocalShareBackwardFilterAlgorithm | |||
| HandleImpl::m_default_local_share_bwd_filter_algo; | |||
| DefaultMatrixMulAlgorithm HandleImpl::m_default_matmul_fwd_algo; | |||
| DefaultBatchedMatrixMulAlgorithm HandleImpl::m_default_batched_matmul_fwd_algo; | |||
| HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, | |||
| HandleType type) | |||
| : HandleImplHelper(computing_handle, type), | |||
| @@ -13,6 +13,7 @@ | |||
| #include "src/common/handle_impl.h" | |||
| #include "src/naive/convolution/algorithms.h" | |||
| #include "src/naive/matrix_mul/algorithms.h" | |||
| #include "src/naive/local_share/algorithms.h" | |||
| #include "src/naive/convolution3d/algorithms.h" | |||
| @@ -46,6 +47,9 @@ class HandleImpl : public HandleImplHelper { | |||
| static DefaultLocalShareBackwardFilterAlgorithm | |||
| m_default_local_share_bwd_filter_algo; | |||
| static DefaultMatrixMulAlgorithm m_default_matmul_fwd_algo; | |||
| static DefaultBatchedMatrixMulAlgorithm m_default_batched_matmul_fwd_algo; | |||
| //! move KernFunc to alloc_kern()->func, destruct func, and call dispatch | |||
| template <typename T> | |||
| void move_kern_func_to_new_kern_and_dispatch(T& func) { | |||
| @@ -109,6 +113,14 @@ public: | |||
| return &m_default_local_share_bwd_filter_algo; | |||
| } | |||
| MatrixMulForward::Algorithm* default_matmul_fwd_algo() { | |||
| return &m_default_matmul_fwd_algo; | |||
| } | |||
| BatchedMatrixMulForward::Algorithm* default_batched_matmul_fwd_algo() { | |||
| return &m_default_batched_matmul_fwd_algo; | |||
| } | |||
| Relayout* relayout_opr() override { | |||
| return get_helper_opr<Relayout, 2>(this); | |||
| } | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * \file dnn/src/naive/matrix_mul/algorithms.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs/linalg.h" | |||
| namespace megdnn { | |||
| namespace naive { | |||
| class DefaultMatrixMulAlgorithm final | |||
| : public megdnn::MatrixMulForward::Algorithm { | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "DEFAULT"; } | |||
| uint32_t type() const override { return 0; } | |||
| }; | |||
| class DefaultBatchedMatrixMulAlgorithm final | |||
| : public megdnn::BatchedMatrixMulForward::Algorithm { | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "DEFAULT"; } | |||
| uint32_t type() const override { return 0; } | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -81,6 +81,20 @@ void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| MIDOUT_END(); | |||
| } | |||
| std::vector<MatrixMulForward::Algorithm*> | |||
| MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) { | |||
| return {static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo()}; | |||
| } | |||
| MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, | |||
| bool /* reproducible */) { | |||
| return static_cast<HandleImpl*>(handle())->default_matmul_fwd_algo(); | |||
| } | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| @@ -26,17 +27,13 @@ public: | |||
| std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/) override { | |||
| return {}; | |||
| } | |||
| const TensorLayout& /*C*/) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||
| const TensorLayout& /*B*/, | |||
| const TensorLayout& /*C*/, | |||
| size_t /*workspace_limit_in_bytes*/, | |||
| bool /* reproducible */) override { | |||
| return nullptr; | |||
| } | |||
| bool /* reproducible */) override; | |||
| const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||