Browse Source

refactor(dnn): refactor algo interface, use algoinfo instead of global algorithm

GitOrigin-RevId: 479718ac75
tags/v1.2.0
Megvii Engine Team 5 years ago
parent
commit
a1877ee0fa
100 changed files with 2847 additions and 1559 deletions
  1. +198
    -20
      dnn/include/megdnn/oprs/base.h
  2. +1
    -0
      dnn/src/aarch64/conv_bias/fp16/algos.h
  3. +1
    -0
      dnn/src/aarch64/conv_bias/fp32/algos.h
  4. +1
    -0
      dnn/src/aarch64/conv_bias/int8/algos.h
  5. +39
    -13
      dnn/src/aarch64/conv_bias/opr_impl.cpp
  6. +4
    -1
      dnn/src/aarch64/conv_bias/opr_impl.h
  7. +1
    -0
      dnn/src/aarch64/conv_bias/quint8/algos.h
  8. +28
    -1
      dnn/src/aarch64/matrix_mul/algos.h
  9. +45
    -28
      dnn/src/aarch64/matrix_mul/opr_impl.cpp
  10. +6
    -1
      dnn/src/aarch64/matrix_mul/opr_impl.h
  11. +6
    -1
      dnn/src/arm_common/conv_bias/f16/algos.h
  12. +14
    -0
      dnn/src/arm_common/conv_bias/fp32/algos.h
  13. +13
    -0
      dnn/src/arm_common/conv_bias/int8/algos.h
  14. +11
    -2
      dnn/src/arm_common/conv_bias/int8x8x16/algos.h
  15. +80
    -53
      dnn/src/arm_common/conv_bias/opr_impl.cpp
  16. +7
    -2
      dnn/src/arm_common/conv_bias/opr_impl.h
  17. +4
    -0
      dnn/src/arm_common/conv_bias/quint8/algos.h
  18. +2
    -0
      dnn/src/arm_common/convolution/int8x8x32/algos.h
  19. +4
    -0
      dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp
  20. +4
    -0
      dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp
  21. +41
    -31
      dnn/src/arm_common/convolution/opr_impl.cpp
  22. +8
    -5
      dnn/src/arm_common/convolution/opr_impl.h
  23. +2
    -0
      dnn/src/arm_common/convolution/quint8/algos.h
  24. +3
    -0
      dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp
  25. +3
    -0
      dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp
  26. +1
    -0
      dnn/src/arm_common/elemwise/opr_impl.h
  27. +7
    -0
      dnn/src/arm_common/matrix_mul/algos.h
  28. +31
    -12
      dnn/src/arm_common/matrix_mul/opr_impl.cpp
  29. +8
    -1
      dnn/src/arm_common/matrix_mul/opr_impl.h
  30. +3
    -0
      dnn/src/arm_common/pooling/opr_impl.h
  31. +1
    -0
      dnn/src/armv7/conv_bias/int8/algos.h
  32. +26
    -8
      dnn/src/armv7/conv_bias/opr_impl.cpp
  33. +4
    -1
      dnn/src/armv7/conv_bias/opr_impl.h
  34. +1
    -0
      dnn/src/armv7/conv_bias/quint8/algos.h
  35. +24
    -1
      dnn/src/armv7/matrix_mul/algos.h
  36. +42
    -24
      dnn/src/armv7/matrix_mul/opr_impl.cpp
  37. +7
    -1
      dnn/src/armv7/matrix_mul/opr_impl.h
  38. +101
    -0
      dnn/src/common/algo_base.h
  39. +25
    -6
      dnn/src/common/algo_chooser.h
  40. +33
    -0
      dnn/src/common/utils.h
  41. +6
    -0
      dnn/src/cuda/batch_conv_bias/algo.cpp
  42. +17
    -4
      dnn/src/cuda/batch_conv_bias/algo.h
  43. +12
    -9
      dnn/src/cuda/batch_conv_bias/opr_impl.h
  44. +8
    -0
      dnn/src/cuda/batched_matrix_mul/algo.cpp
  45. +26
    -3
      dnn/src/cuda/batched_matrix_mul/algo.h
  46. +3
    -3
      dnn/src/cuda/batched_matrix_mul/brute_force.cpp
  47. +10
    -6
      dnn/src/cuda/batched_matrix_mul/opr_impl.h
  48. +10
    -37
      dnn/src/cuda/conv_bias/algo.cpp
  49. +160
    -31
      dnn/src/cuda/conv_bias/algo.h
  50. +2
    -2
      dnn/src/cuda/conv_bias/bfloat16.cpp
  51. +2
    -2
      dnn/src/cuda/conv_bias/opr_impl.cpp
  52. +14
    -11
      dnn/src/cuda/conv_bias/opr_impl.h
  53. +6
    -0
      dnn/src/cuda/convolution/backward_data/algo.cpp
  54. +163
    -157
      dnn/src/cuda/convolution/backward_data/algo.h
  55. +2
    -2
      dnn/src/cuda/convolution/backward_data/bfloat16.cpp
  56. +3
    -29
      dnn/src/cuda/convolution/backward_data/cudnn.cpp
  57. +6
    -0
      dnn/src/cuda/convolution/backward_filter/algo.cpp
  58. +155
    -147
      dnn/src/cuda/convolution/backward_filter/algo.h
  59. +2
    -2
      dnn/src/cuda/convolution/backward_filter/bfloat16.cpp
  60. +3
    -29
      dnn/src/cuda/convolution/backward_filter/cudnn.cpp
  61. +1
    -10
      dnn/src/cuda/convolution/opr_impl.cpp
  62. +114
    -92
      dnn/src/cuda/convolution/opr_impl.h
  63. +7
    -1
      dnn/src/cuda/convolution3d/backward_data/algo.cpp
  64. +134
    -127
      dnn/src/cuda/convolution3d/backward_data/algo.h
  65. +3
    -21
      dnn/src/cuda/convolution3d/backward_data/cudnn.cpp
  66. +9
    -3
      dnn/src/cuda/convolution3d/backward_filter/algo.cpp
  67. +140
    -140
      dnn/src/cuda/convolution3d/backward_filter/algo.h
  68. +3
    -23
      dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp
  69. +11
    -5
      dnn/src/cuda/convolution3d/forward/algo.cpp
  70. +148
    -151
      dnn/src/cuda/convolution3d/forward/algo.h
  71. +3
    -23
      dnn/src/cuda/convolution3d/forward/cudnn.cpp
  72. +147
    -117
      dnn/src/cuda/convolution3d/opr_impl.h
  73. +131
    -0
      dnn/src/cuda/cudnn_wrapper.cpp
  74. +47
    -3
      dnn/src/cuda/cudnn_wrapper.h
  75. +5
    -0
      dnn/src/cuda/deformable_conv/bwd_data/algo.cpp
  76. +13
    -4
      dnn/src/cuda/deformable_conv/bwd_data/algo.h
  77. +4
    -0
      dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp
  78. +13
    -4
      dnn/src/cuda/deformable_conv/bwd_flt/algo.h
  79. +6
    -0
      dnn/src/cuda/deformable_conv/fwd/algo.cpp
  80. +13
    -4
      dnn/src/cuda/deformable_conv/fwd/algo.h
  81. +65
    -50
      dnn/src/cuda/deformable_conv/opr_impl.h
  82. +6
    -0
      dnn/src/cuda/local_share/backward_data/algo.cpp
  83. +16
    -3
      dnn/src/cuda/local_share/backward_data/algo.h
  84. +6
    -0
      dnn/src/cuda/local_share/backward_filter/algo.cpp
  85. +16
    -3
      dnn/src/cuda/local_share/backward_filter/algo.h
  86. +6
    -0
      dnn/src/cuda/local_share/forward/algo.cpp
  87. +17
    -3
      dnn/src/cuda/local_share/forward/algo.h
  88. +32
    -24
      dnn/src/cuda/local_share/opr_impl.h
  89. +8
    -0
      dnn/src/cuda/matrix_mul/algos.cpp
  90. +38
    -26
      dnn/src/cuda/matrix_mul/algos.h
  91. +1
    -1
      dnn/src/cuda/matrix_mul/bfloat16.cpp
  92. +11
    -9
      dnn/src/cuda/matrix_mul/opr_impl.h
  93. +5
    -0
      dnn/src/fallback/conv_bias/algos.cpp
  94. +25
    -0
      dnn/src/fallback/conv_bias/algos.h
  95. +6
    -0
      dnn/src/fallback/conv_bias/common.h
  96. +7
    -0
      dnn/src/fallback/conv_bias/conv1x1/algos.h
  97. +1
    -0
      dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h
  98. +8
    -0
      dnn/src/fallback/conv_bias/im2col/algos.h
  99. +84
    -25
      dnn/src/fallback/conv_bias/opr_impl.cpp
  100. +87
    -1
      dnn/src/fallback/conv_bias/opr_impl.h

+ 198
- 20
dnn/include/megdnn/oprs/base.h View File

@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once


@@ -92,24 +93,72 @@ enum class AlgoDataType : uint32_t {
/*! /*!
* \brief Abstract representation of an algorithm for implementing * \brief Abstract representation of an algorithm for implementing
* the operator * the operator
*
* All pointers to Algorithm should be allocated globally and usable
* across multiple megdnn handles, and they should not be freed by
* the caller.
*/ */
class Algorithm { class Algorithm {
public: public:
static constexpr uint32_t INVALID_ALGO_TYPE = static_cast<uint32_t>(-1);
/**
* \brief Algorithm information, we can get real algo from
* AlgorithmInfo::Info::Desc
*/
struct Info {
struct Desc {
//! backend of the algo belonging to
Handle::HandleType handle_type;
//! indicate the real algo implementation
uint32_t type = INVALID_ALGO_TYPE;
//! serialized param of the algo type
std::string param;
bool valid() const { return type != INVALID_ALGO_TYPE; }
void reset() { type = INVALID_ALGO_TYPE; }

bool operator==(const Desc& rhs) const {
return handle_type == rhs.handle_type && type == rhs.type &&
param == rhs.param;
}
} desc;
//! algorithm name
std::string name;
bool is_reproducible;
bool valid() const { return desc.valid(); }
void reset() { desc.reset(); }
//! desc denotes the algo
bool operator==(const Info& rhs) const { return desc == rhs.desc; }
};

virtual ~Algorithm() = default;

/** /**
* \brief whether the execution result is * \brief whether the execution result is
* reproducible across multiple runs. * reproducible across multiple runs.
*/ */
virtual bool is_reproducible() const = 0; virtual bool is_reproducible() const = 0;
virtual const char* name() const = 0; virtual const char* name() const = 0;
//! serialized param
virtual std::string param() const { return {}; }
virtual uint32_t type() const = 0;


Handle::HandleType handle_type() const { return m_handle_type; } Handle::HandleType handle_type() const { return m_handle_type; }
Info info() const {
return {{handle_type(), type(), param()}, name(), is_reproducible()};
}

template <typename T>
static void serialize_write_pod(const T& val, std::string& result) {
result.append(reinterpret_cast<const char*>(&val), sizeof(T));
}

static void serialize_write_pod(const char* val, std::string& result) {
result.append(val, strlen(val));
}

template <typename T>
static T deserialize_read_pod(const std::string& data, size_t offset = 0) {
T ret = *reinterpret_cast<const T*>(&data[offset]);
return ret;
}


protected: protected:
~Algorithm() = default;
Handle::HandleType m_handle_type = Handle::HandleType::NAIVE; Handle::HandleType m_handle_type = Handle::HandleType::NAIVE;
}; };


@@ -127,6 +176,8 @@ class MultiAlgoOpr;
template <class Opr> template <class Opr>
class MultiAlgoOpr<Opr, -1> { class MultiAlgoOpr<Opr, -1> {
public: public:
using AlgorithmInfo = detail::Algorithm::Info;
using AlgorithmDesc = detail::Algorithm::Info::Desc;
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;
/*! /*!
* \brief get a string representation for current algorithm set; * \brief get a string representation for current algorithm set;
@@ -139,8 +190,8 @@ public:


//! policy for executing the operator //! policy for executing the operator
struct ExecutionPolicy { struct ExecutionPolicy {
//! nullptr means using heuristic
Algorithm* algorithm = nullptr;
//! INVALID_ALGO_TYPE algo_type means using heuristic
AlgorithmInfo algo;
}; };


ExecutionPolicy& execution_policy() { return m_execution_policy; } ExecutionPolicy& execution_policy() { return m_execution_policy; }
@@ -161,6 +212,39 @@ template <class Opr>
class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> { class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> {
public: public:
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;

//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2)) {
ret.emplace_back(algo->info());
}
return ret;
}

/**
* \brief Returns the best algorithm information which indicates the
* algorithm by heuristic.
*
* The selected algorithm should not use workspace more than
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) {
return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes,
reproducible)
->info();
}

protected:
~MultiAlgoOpr() = default;


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms( virtual std::vector<Algorithm*> get_all_algorithms(
@@ -179,9 +263,6 @@ public:
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0; bool reproducible = false) = 0;

protected:
~MultiAlgoOpr() = default;
}; };


//! specialize for nargs == 4 //! specialize for nargs == 4
@@ -189,6 +270,40 @@ template <class Opr>
class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> { class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> {
public: public:
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;

//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) {
ret.emplace_back(algo->info());
}
return ret;
}

/**
* \brief Returns the best algorithm information which indicates the
* algorithm by heuristic.
*
* The selected algorithm should not use workspace more than
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) {
return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes,
reproducible)
->info();
}

protected:
~MultiAlgoOpr() = default;


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms( virtual std::vector<Algorithm*> get_all_algorithms(
@@ -207,9 +322,6 @@ public:
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0; bool reproducible = false) = 0;

protected:
~MultiAlgoOpr() = default;
}; };


//! specialize for nargs == 5 //! specialize for nargs == 5
@@ -217,6 +329,42 @@ template <class Opr>
class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> { class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> {
public: public:
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;

//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
const TensorLayout& p1,
const TensorLayout& p2,
const TensorLayout& p3,
const TensorLayout& p4) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) {
ret.emplace_back(algo->info());
}
return ret;
}

/**
* \brief Returns the best algorithm information which indicates the
* algorithm by heuristic.
*
* The selected algorithm should not use workspace more than
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4,
workspace_limit_in_bytes, reproducible)
->info();
}

protected:
~MultiAlgoOpr() = default;


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms( virtual std::vector<Algorithm*> get_all_algorithms(
@@ -237,9 +385,6 @@ public:
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0; bool reproducible = false) = 0;

protected:
~MultiAlgoOpr() = default;
}; };


//! specialize for nargs == 8 //! specialize for nargs == 8
@@ -247,6 +392,42 @@ template <class Opr>
class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> { class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> {
public: public:
using Algorithm = detail::Algorithm; using Algorithm = detail::Algorithm;
using AlgorithmInfo = detail::Algorithm::Info;

//! get all possible algorithm descriptions for the specified layouts
std::vector<AlgorithmInfo> get_all_algorithms_info(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7) {
std::vector<AlgorithmInfo> ret;
for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) {
ret.emplace_back(algo->info());
}
return ret;
}

/**
* \brief Returns the best algorithm information which indicates the
* algorithm by heuristic.
*
* The selected algorithm should not use workspace more than
* \p workspace_limit_in_bytes.
*/
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& p0, const TensorLayout& p1,
const TensorLayout& p2, const TensorLayout& p3,
const TensorLayout& p4, const TensorLayout& p5,
const TensorLayout& p6, const TensorLayout& p7,
size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(),
bool reproducible = false) {
return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7,
workspace_limit_in_bytes, reproducible)
->info();
}

protected:
~MultiAlgoOpr() = default;


//! get all possible algorithms for the specified layouts //! get all possible algorithms for the specified layouts
virtual std::vector<Algorithm*> get_all_algorithms( virtual std::vector<Algorithm*> get_all_algorithms(
@@ -269,9 +450,6 @@ public:
size_t workspace_limit_in_bytes = size_t workspace_limit_in_bytes =
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
bool reproducible = false) = 0; bool reproducible = false) = 0;

protected:
~MultiAlgoOpr() = default;
}; };
} // namespace detail } // namespace detail
} // namespace megdnn } // namespace megdnn


+ 1
- 0
dnn/src/aarch64/conv_bias/fp16/algos.h View File

@@ -31,6 +31,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP16)
}; };
} // namespace aarch64 } // namespace aarch64
} // namespace megdnn } // namespace megdnn


+ 1
- 0
dnn/src/aarch64/conv_bias/fp32/algos.h View File

@@ -36,6 +36,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(AARCH64_DIRECT_STRD2_FP32)
}; };


} // namespace aarch64 } // namespace aarch64


+ 1
- 0
dnn/src/aarch64/conv_bias/int8/algos.h View File

@@ -48,6 +48,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
} }
MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_S8)
}; };


} // namespace aarch64 } // namespace aarch64


+ 39
- 13
dnn/src/aarch64/conv_bias/opr_impl.cpp View File

@@ -32,28 +32,54 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoF16DirectStride2 f16_direct_stride2; AlgoF16DirectStride2 f16_direct_stride2;
#endif #endif


fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_matmul_algos;

public: public:
AlgoPack() { AlgoPack() {
matmul_algos.emplace_back(&qu8_matrix_mul);
matmul_algos.emplace_back(&s8_matrix_mul);
m_matmul_algos.emplace_back(&qu8_matrix_mul);
m_matmul_algos.emplace_back(&s8_matrix_mul);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
direct_algos.emplace_back(&f16_direct_stride2);
m_direct_algos.emplace_back(&f16_direct_stride2);
#endif #endif
direct_algos.emplace_back(&f32_direct_stride2);
m_direct_algos.emplace_back(&f32_direct_stride2);

for (auto&& algo : m_direct_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
for (auto&& algo : m_matmul_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}

const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const {
return m_direct_algos;
}
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos()
const {
return m_matmul_algos;
} }
SmallVector<AlgoBase*> direct_algos;
SmallVector<AlgoBase*> matmul_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = arm_common::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
sl_algo_pack.direct_algos.end());
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
static AlgoPack algo_pack;
return algo_pack;
}

MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl)

SmallVector<fallback::ConvBiasImpl::AlgoBase*>
ConvBiasImpl::get_all_packed_algo() {
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().direct_algos().begin(),
algo_pack().direct_algos().end());
//! We put matmul algos at the begin. Because matmul will get privilege when //! We put matmul algos at the begin. Because matmul will get privilege when
//! prefer return true. See //! prefer return true. See
algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(),
sl_algo_pack.matmul_algos.end());
algos.insert(algos.begin(), algo_pack().matmul_algos().begin(),
algo_pack().matmul_algos().end());
return std::move(algos); return std::move(algos);
} }




+ 4
- 1
dnn/src/aarch64/conv_bias/opr_impl.h View File

@@ -25,7 +25,9 @@ public:
} }
}; };


SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override;

MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl);


protected: protected:
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;
@@ -38,6 +40,7 @@ private:
class AlgoF16DirectStride2; class AlgoF16DirectStride2;
#endif #endif
class AlgoPack; class AlgoPack;
static const AlgoPack& algo_pack();
}; };


} // namespace aarch64 } // namespace aarch64


+ 1
- 0
dnn/src/aarch64/conv_bias/quint8/algos.h View File

@@ -48,6 +48,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL};
} }
MEGDNN_DECL_ALGO_TYPE(AARCH64_MATMUL_QU8)
}; };
} // namespace aarch64 } // namespace aarch64
} // namespace megdnn } // namespace megdnn


+ 28
- 1
dnn/src/aarch64/matrix_mul/algos.h View File

@@ -27,6 +27,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K8X12X1)
}; };


class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
@@ -37,6 +38,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_K8X12X1)
}; };


class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase {
@@ -47,6 +49,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_K4X16X1)
}; };


class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase {
@@ -58,10 +61,17 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16)
}; };


class MatrixMulImpl::AlgoF32Gemv final class MatrixMulImpl::AlgoF32Gemv final
: public arm_common::MatrixMulImpl::AlgoF32Gemv {};
: public arm_common::MatrixMulImpl::AlgoF32Gemv {
public:
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() {
m_handle_type = Handle::HandleType::AARCH64;
}
MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_GEMV)
};


#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase {
@@ -72,6 +82,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_K8X24X1)
}; };


class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase {
@@ -83,6 +94,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_8X8)
}; };


#endif #endif
@@ -98,6 +110,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X12X4_DOTPROD)
}; };


class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase {
@@ -110,6 +123,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD)
}; };
#else #else


@@ -124,6 +138,7 @@ public:
PackMode packmode() const override { return PackMode::DEFAULT; } PackMode packmode() const override { return PackMode::DEFAULT; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_4X4X16)
}; };


class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase {
@@ -136,6 +151,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K4X4X16)
}; };


class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase {
@@ -147,6 +163,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8)
}; };
#endif #endif


@@ -160,6 +177,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K8X8X8)
}; };


class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase {
@@ -171,6 +189,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_K4X4X16)
}; };


class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
@@ -186,6 +205,7 @@ public:
PackMode packmode() const override { return PackMode::DEFAULT; } PackMode packmode() const override { return PackMode::DEFAULT; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_16X12X4)
}; };


class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
@@ -201,6 +221,7 @@ public:
PackMode packmode() const override { return PackMode::DEFAULT; } PackMode packmode() const override { return PackMode::DEFAULT; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_K8X8X8)
}; };


class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
@@ -214,6 +235,7 @@ public:
PackMode packmode() const override { return PackMode::DEFAULT; } PackMode packmode() const override { return PackMode::DEFAULT; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X16_MK4_4X4X8)
}; };


class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
@@ -225,6 +247,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_K12X8X1)
}; };


class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase {
@@ -236,6 +259,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8)
}; };


#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
@@ -249,6 +273,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD)
}; };


class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
@@ -262,6 +287,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT)
MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD)
}; };
#else #else


@@ -273,6 +299,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8)
}; };
#endif #endif




+ 45
- 28
dnn/src/aarch64/matrix_mul/opr_impl.cpp View File

@@ -51,49 +51,66 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoQuint8K8x8x8 quint8_k8x8x8; AlgoQuint8K8x8x8 quint8_k8x8x8;
#endif #endif


SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
public: public:
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;


AlgoPack() { AlgoPack() {
all_algos.emplace_back(&f32_gemv);
all_algos.emplace_back(&f32K8x12x1);
all_algos.emplace_back(&f32_mk4_8x12x1);
all_algos.emplace_back(&f32k4x16x1);
all_algos.emplace_back(&f32mk4_4x16);
m_all_algos.emplace_back(&f32_gemv);
m_all_algos.emplace_back(&f32K8x12x1);
m_all_algos.emplace_back(&f32_mk4_8x12x1);
m_all_algos.emplace_back(&f32k4x16x1);
m_all_algos.emplace_back(&f32mk4_4x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16_k8x24x1);
all_algos.emplace_back(&f16_mk8_8x8);
m_all_algos.emplace_back(&f16_k8x24x1);
m_all_algos.emplace_back(&f16_mk8_8x8);
#endif #endif
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod);
m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod);
#else #else
all_algos.emplace_back(&int8x8x32_k4x4x16);
all_algos.emplace_back(&int8x8x32_k8x8x8);
all_algos.emplace_back(&int8x8x32_mk4_4x4x16);
m_all_algos.emplace_back(&int8x8x32_k4x4x16);
m_all_algos.emplace_back(&int8x8x32_k8x8x8);
m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16);
#endif #endif
all_algos.emplace_back(&int8x8x16_k4x4x16);
all_algos.emplace_back(&int8x8x16_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_4x4x8);
all_algos.emplace_back(&int8x8x16_mk4_16x12x4);
m_all_algos.emplace_back(&int8x8x16_k4x4x16);
m_all_algos.emplace_back(&int8x8x16_k8x8x8);
m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8);
m_all_algos.emplace_back(&int8x8x16_mk4_4x4x8);
m_all_algos.emplace_back(&int8x8x16_mk4_16x12x4);


all_algos.emplace_back(&int16x16x32_k12x8x1);
all_algos.emplace_back(&int16x16x32_mk8_8x8);
m_all_algos.emplace_back(&int16x16x32_k12x8x1);
m_all_algos.emplace_back(&int16x16x32_mk8_8x8);
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&quint8_gemv_dotprod);
all_algos.emplace_back(&quint8_k8x8x4_dotprod);
m_all_algos.emplace_back(&quint8_gemv_dotprod);
m_all_algos.emplace_back(&quint8_k8x8x4_dotprod);
#else #else
all_algos.emplace_back(&quint8_k8x8x8);
m_all_algos.emplace_back(&quint8_k8x8x8);
#endif #endif

for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}

const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const {
return m_all_algos;
} }
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack;
auto&& algos = arm_common::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
s_algo_pack.all_algos.end());
const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() {
static AlgoPack algo_pack;
return algo_pack;
}

MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl)

SmallVector<fallback::MatrixMulImpl::AlgoBase*>
MatrixMulImpl::get_all_packed_algo() {
auto&& algos = arm_common::MatrixMulImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().all_algos().begin(),
algo_pack().all_algos().end());
return std::move(algos); return std::move(algos);
} }




+ 6
- 1
dnn/src/aarch64/matrix_mul/opr_impl.h View File

@@ -25,7 +25,10 @@ public:
} }
}; };


SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo()
override;

MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl);


private: private:
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1
@@ -66,6 +69,8 @@ private:
class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16


class AlgoPack; class AlgoPack;
public:
static const AlgoPack& algo_pack();
}; };


} // namespace aarch64 } // namespace aarch64


+ 6
- 1
dnn/src/arm_common/conv_bias/f16/algos.h View File

@@ -30,6 +30,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_FP16)
}; };


class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase {
@@ -45,7 +46,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP16)
}; };
class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase {
public: public:
@@ -61,6 +62,7 @@ public:
} }


MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP16)
}; };
class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase {
public: public:
@@ -75,6 +77,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16)
}; };


class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { class ConvBiasImpl::AlgoF16Direct final : public AlgoBase {
@@ -94,6 +97,7 @@ public:
ConvAlgoTypePack get_algo_type() const override{ ConvAlgoTypePack get_algo_type() const override{
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP16)
}; };


class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase {
@@ -110,6 +114,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP16)
}; };


} // namespace arm_common } // namespace arm_common


+ 14
- 0
dnn/src/arm_common/conv_bias/fp32/algos.h View File

@@ -30,6 +30,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32)
}; };


class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase {
@@ -45,6 +46,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32)
}; };


class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase {
@@ -60,6 +62,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32)
}; };


class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase {
@@ -75,6 +78,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32)
}; };


class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase {
@@ -90,6 +94,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32)
}; };


//===================== NCHW44 Winograd Support =====================// //===================== NCHW44 Winograd Support =====================//
@@ -107,6 +112,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32)
}; };


class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase {
@@ -123,6 +129,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32)
}; };


class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase {
@@ -139,6 +146,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32)
}; };
// ================================================================= // // ================================================================= //


@@ -157,6 +165,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP32)
}; };


class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
@@ -174,6 +183,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP32)
}; };


class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
@@ -191,6 +201,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_FP32)
}; };


class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
@@ -209,6 +220,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_FP32)
}; };


class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase {
@@ -227,6 +239,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_FP32)
}; };


class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase {
@@ -244,6 +257,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW44_F32)
}; };


} // namespace arm_common } // namespace arm_common


+ 13
- 0
dnn/src/arm_common/conv_bias/int8/algos.h View File

@@ -33,6 +33,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_S8)
}; };


class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
@@ -49,6 +50,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_S8)
}; };


class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
@@ -65,6 +67,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44)
}; };


class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
@@ -81,6 +84,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_S8)
}; };


class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase {
@@ -95,6 +99,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD1_NCHW44_S8)
}; };


class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase {
@@ -109,6 +114,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8)
}; };


#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
@@ -126,6 +132,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8)
}; };


class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
@@ -142,6 +149,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_S8)
}; };


class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase {
@@ -159,6 +167,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_S8)
}; };


class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase {
@@ -180,6 +189,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_DOT_S8)
}; };
#endif #endif


@@ -196,6 +206,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8)
}; };


//=======================input int8 compute fp32 output int8============ //=======================input int8 compute fp32 output int8============
@@ -213,6 +224,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32)
}; };


//=======================input int8 compute int16 output int8============ //=======================input int8 compute int16 output int8============
@@ -231,6 +243,7 @@ public:
} }


MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8)
}; };


} // namespace arm_common } // namespace arm_common


+ 11
- 2
dnn/src/arm_common/conv_bias/int8x8x16/algos.h View File

@@ -39,6 +39,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_INT8X8X16)
}; };


class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase {
@@ -54,6 +55,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_INT8X8X16)
}; };


class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase {
@@ -80,6 +82,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_INT8X8X16)
}; };


class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase {
@@ -96,12 +99,16 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16)
}; };


class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase {
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final
: public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; }
const char* name() const override {
return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44";
}
bool usable(const NCBKernSizeParam& param, bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override; AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace( size_t get_workspace(
@@ -111,6 +118,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16)
}; };


class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase {
@@ -129,6 +137,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16)
}; };


} // namespace arm_common } // namespace arm_common


+ 80
- 53
dnn/src/arm_common/conv_bias/opr_impl.cpp View File

@@ -88,46 +88,50 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
#endif #endif


SmallVector<std::unique_ptr<AlgoBase>> refhold; SmallVector<std::unique_ptr<AlgoBase>> refhold;
fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_direct_algos;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_winograd_algos;


public: public:
AlgoPack() { AlgoPack() {
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
direct_algos.emplace_back(&ds8_direct_stride1);
direct_algos.emplace_back(&ds8_direct_stride2);
direct_algos.emplace_back(&du8_direct_stride1);
direct_algos.emplace_back(&du8_direct_stride2);
m_direct_algos.emplace_back(&ds8_direct_stride1);
m_direct_algos.emplace_back(&ds8_direct_stride2);
m_direct_algos.emplace_back(&du8_direct_stride1);
m_direct_algos.emplace_back(&du8_direct_stride2);


direct_algos.emplace_back(&ds8_direct_nchw44);
direct_algos.emplace_back(&ds8_direct_nchw_nchw44);
m_direct_algos.emplace_back(&ds8_direct_nchw44);
m_direct_algos.emplace_back(&ds8_direct_nchw_nchw44);
#endif #endif
direct_algos.emplace_back(&qu8_direct_stride2);
direct_algos.emplace_back(&qu8_direct_stride1);
direct_algos.emplace_back(&s8_direct_stride2);
direct_algos.emplace_back(&s8_direct_nchw44);
direct_algos.emplace_back(&s8x8x16_direct_nchw44);
direct_algos.emplace_back(&s8_direct_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_stride1);

direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44);
direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);
m_direct_algos.emplace_back(&qu8_direct_stride2);
m_direct_algos.emplace_back(&qu8_direct_stride1);
m_direct_algos.emplace_back(&s8_direct_stride2);
m_direct_algos.emplace_back(&s8_direct_nchw44);
m_direct_algos.emplace_back(&s8x8x16_direct_nchw44);
m_direct_algos.emplace_back(&s8_direct_nchw_nchw44);
m_direct_algos.emplace_back(&s8_direct_stride1);

m_direct_algos.emplace_back(
&s8x8x16_channel_wise_stride1_stride2_nchw44);
m_direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44);
m_direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44);


#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
direct_algos.emplace_back(&f16_direct_stride1);
direct_algos.emplace_back(&f16_direct);
m_direct_algos.emplace_back(&f16_direct_stride1);
m_direct_algos.emplace_back(&f16_direct);
#endif #endif
direct_algos.emplace_back(&i8x8x16_direct);
direct_algos.emplace_back(&i8x8x16_stride2_filter2);
direct_algos.emplace_back(&i8x8x16_stride2);
direct_algos.emplace_back(&i8x8x16_nchw_nchw44);
m_direct_algos.emplace_back(&i8x8x16_direct);
m_direct_algos.emplace_back(&i8x8x16_stride2_filter2);
m_direct_algos.emplace_back(&i8x8x16_stride2);
m_direct_algos.emplace_back(&i8x8x16_nchw_nchw44);


direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&f32_chanel_wise_nchw44);
direct_algos.emplace_back(&f32_direct_nchw44);
m_direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44);
m_direct_algos.emplace_back(&f32_chanel_wise_nchw44);
m_direct_algos.emplace_back(&f32_direct_nchw44);


direct_algos.emplace_back(&f32_direct_stride1);
direct_algos.emplace_back(&f32_direct_stride2);
direct_algos.emplace_back(&f32_direct);
m_direct_algos.emplace_back(&f32_direct_stride1);
m_direct_algos.emplace_back(&f32_direct_stride2);
m_direct_algos.emplace_back(&f32_direct);


static CpuOprDelegationStorage<2> storage; static CpuOprDelegationStorage<2> storage;
auto matmul_opr = storage.get<MatrixMul, 0>(); auto matmul_opr = storage.get<MatrixMul, 0>();
@@ -143,31 +147,31 @@ public:
refhold.emplace_back(new AlgoFP32WinogradF23_4x4( refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
//! uncomment this when low precision mode is done //! uncomment this when low precision mode is done
#if 0 #if 0
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
#endif #endif
//! Qint8x8x32 winograd compute with fp32 //! Qint8x8x32 winograd compute with fp32
refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
} }
} }
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
@@ -180,15 +184,15 @@ public:
refhold.emplace_back(new AlgoFP32WinogradF63( refhold.emplace_back(new AlgoFP32WinogradF63(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF54( refhold.emplace_back(new AlgoFP32WinogradF54(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP32WinogradF45( refhold.emplace_back(new AlgoFP32WinogradF45(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
} }
} }


@@ -203,15 +207,15 @@ public:
refhold.emplace_back(new AlgoFP16WinogradF23( refhold.emplace_back(new AlgoFP16WinogradF23(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP16WinogradF45( refhold.emplace_back(new AlgoFP16WinogradF45(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoFP16WinogradF63( refhold.emplace_back(new AlgoFP16WinogradF63(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
} }
} }
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
@@ -224,7 +228,7 @@ public:
refhold.emplace_back(new AlgoFP16WinogradF23_8x8( refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
} }
} }
#endif #endif
@@ -238,25 +242,48 @@ public:
refhold.emplace_back(new AlgoS8WinogradF23_8x8( refhold.emplace_back(new AlgoS8WinogradF23_8x8(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44( refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size)); tile_size));
winograd_algos.emplace_back(refhold.back().get());
m_winograd_algos.emplace_back(refhold.back().get());
} }
} }


for (auto&& algo : m_direct_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
for (auto&& algo : m_winograd_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }
SmallVector<AlgoBase*> direct_algos;
SmallVector<AlgoBase*> winograd_algos;

const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos()
const {
return m_direct_algos;
}
const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& winograd_algos()
const {
return m_winograd_algos;
}
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = fallback::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
sl_algo_pack.direct_algos.end());
algos.insert(algos.end(), sl_algo_pack.winograd_algos.begin(),
sl_algo_pack.winograd_algos.end());
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
static AlgoPack algo_pack;
return algo_pack;
}

MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl)

SmallVector<fallback::ConvBiasImpl::AlgoBase*>
ConvBiasImpl::get_all_packed_algo() {
auto&& algos = fallback::ConvBiasImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().direct_algos().begin(),
algo_pack().direct_algos().end());
algos.insert(algos.end(), algo_pack().winograd_algos().begin(),
algo_pack().winograd_algos().end());
return std::move(algos); return std::move(algos);
} }




+ 7
- 2
dnn/src/arm_common/conv_bias/opr_impl.h View File

@@ -12,6 +12,7 @@
#pragma once #pragma once
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/opr_impl.h"
#include "src/common/algo_base.h"


namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {
@@ -27,7 +28,7 @@ public:
} }
}; };


SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override;


bool is_matmul_quantized_prefer( bool is_matmul_quantized_prefer(
const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param)
@@ -35,7 +36,8 @@ public:


SmallVector<AlgoCategory> suggest_algo_category_order( SmallVector<AlgoCategory> suggest_algo_category_order(
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;
class AlgoPack;

MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl);


protected: protected:
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;
@@ -95,6 +97,9 @@ private:
class AlgoF16Direct; class AlgoF16Direct;
class AlgoF16DirectStride1; class AlgoF16DirectStride1;
#endif #endif

class AlgoPack;
static const AlgoPack& algo_pack();
}; };


} // namespace arm_common } // namespace arm_common


+ 4
- 0
dnn/src/arm_common/conv_bias/quint8/algos.h View File

@@ -32,6 +32,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_QU8)
}; };


class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase {
@@ -48,6 +49,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8)
}; };
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase {
@@ -65,6 +67,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8)
}; };


class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase {
@@ -81,6 +84,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
} }
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8)
}; };
#endif #endif
} // namespace arm_common } // namespace arm_common


+ 2
- 0
dnn/src/arm_common/convolution/int8x8x32/algos.h View File

@@ -36,6 +36,7 @@ public:


ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32)
}; };


class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final
@@ -54,6 +55,7 @@ public:


ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_INT8X8X32)
}; };


#endif #endif


+ 4
- 0
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp View File

@@ -1086,6 +1086,10 @@ bool deconv::can_stride1_int8x8x32_dot(const NCBKernSizeParam& param) {
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
FH >= PH + 1 && FW >= PW + 1; FH >= PH + 1 && FW >= PW + 1;


avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
param.filter_type.enumv() == DTypeEnum::Int8) &&
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
param.grad_type.enumv() == DTypeEnum::Int32);
return avaiable && return avaiable &&
((FH == 2 && OC <= 8) || ((FH == 2 && OC <= 8) ||
((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16))); ((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16)));


+ 4
- 0
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp View File

@@ -1180,6 +1180,10 @@ bool deconv::can_stride2_int8x8x32_dot(const NCBKernSizeParam& param) {
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
FH >= PH + 1 && FW >= PW + 1; FH >= PH + 1 && FW >= PW + 1;


avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
param.filter_type.enumv() == DTypeEnum::Int8) &&
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
param.grad_type.enumv() == DTypeEnum::Int32);
return avaiable && ((FH == 2 && OC <= 4) || (FH == 3 && OC <= 8) || return avaiable && ((FH == 2 && OC <= 4) || (FH == 3 && OC <= 8) ||
(FH == 5 && OC <= 16) || (FH == 7 && OC < 32)); (FH == 5 && OC <= 16) || (FH == 7 && OC < 32));
} }


+ 41
- 31
dnn/src/arm_common/convolution/opr_impl.cpp View File

@@ -23,15 +23,54 @@ using namespace arm_common;




/* ===================== ConvolutionBackwardData ===================== */ /* ===================== ConvolutionBackwardData ===================== */
struct ConvolutionBackwardDataImpl::AlgoPack {
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot;
AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot;
AlgoUdot8DirectStride1 quint8_direct_stride1_udot; AlgoUdot8DirectStride1 quint8_direct_stride1_udot;
AlgoUdot8DirectStride2 quint8_direct_stride2_udot; AlgoUdot8DirectStride2 quint8_direct_stride2_udot;
#endif #endif

fallback::ConvolutionBackwardDataImpl::AlgoBase::Mapper m_all_algos_map;
SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>
m_all_algos;

public:
AlgoPack() {
#if __ARM_FEATURE_DOTPROD
m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot);
m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot);
m_all_algos.emplace_back(&quint8_direct_stride1_udot);
m_all_algos.emplace_back(&quint8_direct_stride2_udot);
#endif

for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}

const SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>&
all_algos() const {
return m_all_algos;
}
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };
ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;

const ConvolutionBackwardDataImpl::AlgoPack&
ConvolutionBackwardDataImpl::algo_pack() {
static AlgoPack algo_pack;
return algo_pack;
}

MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl)

SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>
ConvolutionBackwardDataImpl::get_all_packed_algo() {
auto&& algos = fallback::ConvolutionBackwardDataImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().all_algos().begin(),
algo_pack().all_algos().end());
return std::move(algos);
}


ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
@@ -52,35 +91,6 @@ size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
param); param);
} }


std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
const NCBKernSizeParam& param) {
auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
param);

#if __ARM_FEATURE_DOTPROD
if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
param.filter_type.enumv() == DTypeEnum::Int8) &&
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
param.grad_type.enumv() == DTypeEnum::Int32)) {
if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot);
}
if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot);
}
} else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.grad_type.enumv() == DTypeEnum::QuantizedS32) {
if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot);
}
if (sm_algo_pack.quint8_direct_stride2_udot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride2_udot);
}
}
#endif
return ret;
}
const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
// arm common version 0 // arm common version 0
return "DeconvAC0"; return "DeconvAC0";


+ 8
- 5
dnn/src/arm_common/convolution/opr_impl.h View File

@@ -47,11 +47,14 @@ protected:
size_t ncb_1g_get_workspace(Algorithm* algo, size_t ncb_1g_get_workspace(Algorithm* algo,
const NCBKernSizeParam& param) override; const NCBKernSizeParam& param) override;


std::vector<Algorithm*> ncb_1g_get_all_algorithms(
const NCBKernSizeParam& param) override;

const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>
get_all_packed_algo() override;

public:
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl);

private: private:
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
class AlgoSdot8DirectStride1; class AlgoSdot8DirectStride1;
@@ -59,8 +62,8 @@ private:
class AlgoUdot8DirectStride1; class AlgoUdot8DirectStride1;
class AlgoUdot8DirectStride2; class AlgoUdot8DirectStride2;
#endif #endif
struct AlgoPack;
static AlgoPack sm_algo_pack;
class AlgoPack;
static const AlgoPack& algo_pack();
}; };


} // namespace arm_common } // namespace arm_common


+ 2
- 0
dnn/src/arm_common/convolution/quint8/algos.h View File

@@ -36,6 +36,7 @@ public:
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;


MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8)
}; };


class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final
@@ -55,6 +56,7 @@ public:
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;


MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8)
}; };
#endif #endif
} // namespace arm_common } // namespace arm_common


+ 3
- 0
dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp View File

@@ -1236,6 +1236,9 @@ bool deconv::can_stride1_quint8_dot(const NCBKernSizeParam& param) {
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
FH >= PH + 1 && FW >= PW + 1; FH >= PH + 1 && FW >= PW + 1;


avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm ||
param.grad_type.enumv() == DTypeEnum::Int32);

/** /**
* \note In the kernel, we use int32_t to calc the value, in order * \note In the kernel, we use int32_t to calc the value, in order
* not generate negative number, we first initialize SHIFT and sub * not generate negative number, we first initialize SHIFT and sub


+ 3
- 0
dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp View File

@@ -1337,6 +1337,9 @@ bool deconv::can_stride2_quint8_dot(const NCBKernSizeParam& param) {
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && (FH == 2 || FH == 3 || FH == 5 || FH == 7) &&
FH >= PH + 1 && FW >= PW + 1; FH >= PH + 1 && FW >= PW + 1;


avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm ||
param.grad_type.enumv() == DTypeEnum::Int32);

/** /**
* \note In the kernel, we use uint32_t to calc the value, in order * \note In the kernel, we use uint32_t to calc the value, in order
* not generate negative number, we first initialize SHIFT and sub * not generate negative number, we first initialize SHIFT and sub


+ 1
- 0
dnn/src/arm_common/elemwise/opr_impl.h View File

@@ -59,6 +59,7 @@ public:
virtual bool is_available(const KernParam&) const = 0; virtual bool is_available(const KernParam&) const = 0;
virtual void exec(const KernParam&) const = 0; virtual void exec(const KernParam&) const = 0;
virtual ~AlgoBase() = default; virtual ~AlgoBase() = default;
uint32_t type() const override { return INVALID_ALGO_TYPE; };
}; };


#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC


+ 7
- 0
dnn/src/arm_common/matrix_mul/algos.h View File

@@ -26,6 +26,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X16)
}; };


class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
@@ -39,6 +40,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV)
}; };


class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
@@ -52,6 +54,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4)
}; };


#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
@@ -66,6 +69,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4_DOT)
}; };
#endif #endif


@@ -96,6 +100,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4)
}; };


#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -110,6 +115,7 @@ public:
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F16_GEMV)
}; };
#endif #endif


@@ -130,6 +136,7 @@ public:
static_cast<uint32_t>(AlgoDataType::FLOAT32) | static_cast<uint32_t>(AlgoDataType::FLOAT32) |
static_cast<uint32_t>(AlgoDataType::QINT8X8X32)), static_cast<uint32_t>(AlgoDataType::QINT8X8X32)),
DEFAULT) DEFAULT)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_GEVM)
}; };


} // namespace arm_common } // namespace arm_common


+ 31
- 12
dnn/src/arm_common/matrix_mul/opr_impl.cpp View File

@@ -28,28 +28,47 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoGevm gevm; AlgoGevm gevm;
AlgoF32GemvMK4 f32_gemv_mk4; AlgoF32GemvMK4 f32_gemv_mk4;


SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;

public: public:
AlgoPack() { AlgoPack() {
all_algos.emplace_back(&int8x8x16);
m_all_algos.emplace_back(&int8x8x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16gemv);
m_all_algos.emplace_back(&f16gemv);
#endif #endif
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
#endif #endif
all_algos.emplace_back(&int8x8x32_gemv);
all_algos.emplace_back(&int8x8x32_gemv_mk4);
all_algos.emplace_back(&f32_gemv_mk4);
all_algos.emplace_back(&gevm);
m_all_algos.emplace_back(&int8x8x32_gemv);
m_all_algos.emplace_back(&int8x8x32_gemv_mk4);
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&gevm);

for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}

const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const {
return m_all_algos;
} }
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() {
static AlgoPack algo_pack;
return algo_pack;
}

MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl)

SmallVector<fallback::MatrixMulImpl::AlgoBase*>
MatrixMulImpl::get_all_packed_algo() {
static AlgoPack s_algo_pack; static AlgoPack s_algo_pack;
auto&& algos = fallback::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
s_algo_pack.all_algos.end());
auto&& algos = fallback::MatrixMulImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().all_algos().begin(),
algo_pack().all_algos().end());
return std::move(algos); return std::move(algos);
} }




+ 8
- 1
dnn/src/arm_common/matrix_mul/opr_impl.h View File

@@ -11,6 +11,7 @@
#pragma once #pragma once
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/matrix_mul/opr_impl.h" #include "src/fallback/matrix_mul/opr_impl.h"
#include "src/common/algo_base.h"


namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {
@@ -27,7 +28,10 @@ public:
} }
}; };


SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo()
override;

MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl);


protected: protected:
class AlgoF32Gemv; // Arm_common F32 Gemv class AlgoF32Gemv; // Arm_common F32 Gemv
@@ -43,6 +47,9 @@ protected:
#endif #endif
class AlgoInt8x8x16; // Arm_common Int 8x8x16 class AlgoInt8x8x16; // Arm_common Int 8x8x16
class AlgoPack; class AlgoPack;

public:
static const AlgoPack& algo_pack();
}; };


} // namespace arm_common } // namespace arm_common


+ 3
- 0
dnn/src/arm_common/pooling/opr_impl.h View File

@@ -10,6 +10,7 @@
* implied. * implied.
*/ */
#pragma once #pragma once
#include "megdnn/oprs/base.h"
#include "src/fallback/pooling/opr_impl.h" #include "src/fallback/pooling/opr_impl.h"


namespace megdnn { namespace megdnn {
@@ -72,6 +73,8 @@ public:
virtual ~AlgoBase() = default; virtual ~AlgoBase() = default;
virtual bool usable(const PoolingKernSizeParam& param) const = 0; virtual bool usable(const PoolingKernSizeParam& param) const = 0;
virtual void exec(const PoolingKernParam& param) const = 0; virtual void exec(const PoolingKernParam& param) const = 0;

uint32_t type() const override { return INVALID_ALGO_TYPE; };
}; };


private: private:


+ 1
- 0
dnn/src/armv7/conv_bias/int8/algos.h View File

@@ -40,6 +40,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
} }
MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_S8)
}; };


} // namespace armv7 } // namespace armv7


+ 26
- 8
dnn/src/armv7/conv_bias/opr_impl.cpp View File

@@ -24,22 +24,40 @@ using namespace armv7;
class ConvBiasImpl::AlgoPack : NonCopyableObj { class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8MatrixMul s8_matrix_mul; AlgoS8MatrixMul s8_matrix_mul;
AlgoQU8MatrixMul qu8_matrix_mul; AlgoQU8MatrixMul qu8_matrix_mul;
fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> m_all_algos;
public: public:
AlgoPack() { AlgoPack() {
all_algos.emplace_back(&qu8_matrix_mul);
all_algos.emplace_back(&s8_matrix_mul);
m_all_algos.emplace_back(&qu8_matrix_mul);
m_all_algos.emplace_back(&s8_matrix_mul);

for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}

const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& all_algos()
const {
return m_all_algos;
} }
SmallVector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = arm_common::ConvBiasImpl::algo_pack();
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
static AlgoPack algo_pack;
return algo_pack;
}

MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl)

SmallVector<fallback::ConvBiasImpl::AlgoBase*>
ConvBiasImpl::get_all_packed_algo() {
auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo();
//! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, //! TODO fused matmul bias is slower than matmul + elemwise in armv7 now,
//! and nearly equal in aarch64, because of the waste of register in //! and nearly equal in aarch64, because of the waste of register in
//! postprocess //! postprocess
algos.insert(algos.end(), sl_algo_pack.all_algos.begin(),
sl_algo_pack.all_algos.end());
algos.insert(algos.end(), algo_pack().all_algos().begin(),
algo_pack().all_algos().end());
return std::move(algos); return std::move(algos);
} }




+ 4
- 1
dnn/src/armv7/conv_bias/opr_impl.h View File

@@ -25,7 +25,9 @@ public:
} }
}; };


SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override;

MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvBiasImpl);


protected: protected:
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;
@@ -34,6 +36,7 @@ private:
class AlgoS8MatrixMul; class AlgoS8MatrixMul;
class AlgoQU8MatrixMul; class AlgoQU8MatrixMul;
class AlgoPack; class AlgoPack;
static const AlgoPack& algo_pack();
}; };


} // namespace armv7 } // namespace armv7


+ 1
- 0
dnn/src/armv7/conv_bias/quint8/algos.h View File

@@ -42,6 +42,7 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL};
} }
MEGDNN_DECL_ALGO_TYPE(ARMV7_MATMUL_QU8)
}; };


} // namespace armv7 } // namespace armv7


+ 24
- 1
dnn/src/armv7/matrix_mul/algos.h View File

@@ -27,6 +27,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32)
}; };


class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
@@ -37,6 +38,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_PACK_4X12)
}; };


class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase {
@@ -48,6 +50,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_MK4_4x8)
}; };


#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -59,6 +62,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_K4X16X1)
}; };
class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase {
public: public:
@@ -69,6 +73,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8)
}; };
#endif #endif
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
@@ -80,6 +85,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_K6X8X4)
}; };


class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase {
@@ -90,6 +96,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X4)
}; };


class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase {
@@ -102,11 +109,18 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8_MK4_8X4X4_DOTPROD)
}; };
#endif #endif


class MatrixMulImpl::AlgoF32Gemv final class MatrixMulImpl::AlgoF32Gemv final
: public arm_common::MatrixMulImpl::AlgoF32Gemv {};
: public arm_common::MatrixMulImpl::AlgoF32Gemv {
public:
AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() {
m_handle_type = Handle::HandleType::ARMV7;
}
MEGDNN_DECL_ALGO_TYPE(ARMV7_F32_GEMV)
};


class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase {
public: public:
@@ -117,6 +131,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X2X16)
}; };


class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase {
@@ -128,6 +143,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_K4X8X8)
}; };


class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase {
@@ -138,6 +154,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_QUINT8_K4X8X8)
}; };


class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase {
@@ -149,6 +166,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X2X16)
}; };


class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase {
@@ -160,6 +178,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_K4X8X8)
}; };


class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
@@ -171,6 +190,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X16_MK4_K8X8X4)
}; };


class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase {
@@ -182,6 +202,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_K12X4X1)
}; };


class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase {
@@ -193,6 +214,7 @@ public:
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT16X16X32_MK8_4X8)
}; };


class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
@@ -204,6 +226,7 @@ public:
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(ARMV7_INT8X8X32_MK4_4X2X16)
}; };


} // namespace armv7 } // namespace armv7


+ 42
- 24
dnn/src/armv7/matrix_mul/opr_impl.cpp View File

@@ -43,42 +43,60 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1;
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8;


SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;

public: public:
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;


AlgoPack() { AlgoPack() {
all_algos.emplace_back(&f32_gemv);
all_algos.emplace_back(&f32);
all_algos.emplace_back(&f32_mk4_pack_4x12);
all_algos.emplace_back(&f32_mk4_4x8);
m_all_algos.emplace_back(&f32_gemv);
m_all_algos.emplace_back(&f32);
m_all_algos.emplace_back(&f32_mk4_pack_4x12);
m_all_algos.emplace_back(&f32_mk4_4x8);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16_k4x16x1);
all_algos.emplace_back(&f16_mk8_4x8);
m_all_algos.emplace_back(&f16_k4x16x1);
m_all_algos.emplace_back(&f16_mk8_4x8);
#endif #endif
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod);
all_algos.emplace_back(&int8_k6x8x4);
all_algos.emplace_back(&quint8_k4x8x4);
m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod);
m_all_algos.emplace_back(&int8_k6x8x4);
m_all_algos.emplace_back(&quint8_k4x8x4);
#endif #endif
all_algos.emplace_back(&int8x8x32_mk4_4x2x16);
all_algos.emplace_back(&int8x8x32_k4x2x16);
all_algos.emplace_back(&int8x8x32_k4x8x8);
all_algos.emplace_back(&quint8_k4x8x8);
all_algos.emplace_back(&int8x8x16_mk4_8x8x4);
all_algos.emplace_back(&int8x8x16_k4x2x16);
all_algos.emplace_back(&int8x8x16_k4x8x8);
m_all_algos.emplace_back(&int8x8x32_mk4_4x2x16);
m_all_algos.emplace_back(&int8x8x32_k4x2x16);
m_all_algos.emplace_back(&int8x8x32_k4x8x8);
m_all_algos.emplace_back(&quint8_k4x8x8);
m_all_algos.emplace_back(&int8x8x16_mk4_8x8x4);
m_all_algos.emplace_back(&int8x8x16_k4x2x16);
m_all_algos.emplace_back(&int8x8x16_k4x8x8);

m_all_algos.emplace_back(&int16x16x32_k12x4x1);
m_all_algos.emplace_back(&int16x16x32_mk8_4x8);

for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}


all_algos.emplace_back(&int16x16x32_k12x4x1);
all_algos.emplace_back(&int16x16x32_mk8_4x8);
const SmallVector<fallback::MatrixMulImpl::AlgoBase*>& all_algos() const {
return m_all_algos;
} }
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack;
auto algos = arm_common::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
s_algo_pack.all_algos.end());
const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() {
static AlgoPack algo_pack;
return algo_pack;
}

SmallVector<fallback::MatrixMulImpl::AlgoBase*>
MatrixMulImpl::get_all_packed_algo() {
auto algos = arm_common::MatrixMulImpl::get_all_packed_algo();
algos.insert(algos.begin(), algo_pack().all_algos().begin(),
algo_pack().all_algos().end());
return algos; return algos;
} }


MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl)

// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 7
- 1
dnn/src/armv7/matrix_mul/opr_impl.h View File

@@ -25,7 +25,10 @@ public:
} }
}; };


SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo()
override;

MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl);


private: private:
class AlgoF32; // Armv7 F32 class AlgoF32; // Armv7 F32
@@ -52,6 +55,9 @@ private:
// DotProduct // DotProduct
#endif #endif
class AlgoPack; class AlgoPack;

public:
static const AlgoPack& algo_pack();
}; };


} // namespace armv7 } // namespace armv7


+ 101
- 0
dnn/src/common/algo_base.h View File

@@ -0,0 +1,101 @@
/**
* \file dnn/src/common/algo_base.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include <functional>
#include <string>

#include "megdnn/oprs/base.h"
#include "src/common/utils.h"

namespace megdnn {

#define MEGDNN_DECL_ALGO_TYPE(_type) \
uint32_t type() const override { \
return static_cast<std::underlying_type<AlgoType>::type>( \
AlgoType::_type); \
}

#define MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr) \
static fallback::_opr::AlgoBase* get_algo_from_desc( \
const AlgorithmDesc& desc)

#define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr) \
fallback::_opr::AlgoBase* _opr::get_algo_from_desc( \
const AlgorithmDesc& desc) { \
megdnn_assert(algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
}

#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \
_opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \
megdnn_assert(algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
}

/**
* \brief construct algo from AlgorithmDesc
*/
template <typename AlgoBase>
class AlgoConstructMixin {
private:
std::vector<std::unique_ptr<AlgoBase>> m_refhold;
protected:
typename AlgoBase::Mapper m_all_algos_map;

public:

//! construct the algo which described by desc, and return the instance
AlgoBase* construct_and_get_algo(
const detail::Algorithm::Info::Desc& desc) {
auto iter = m_all_algos_map.find(desc);
if (iter != m_all_algos_map.end()) {
return m_all_algos_map.at(desc);
}
std::string serialized_bin;
AlgoBase::serialize_write_pod(desc.type, serialized_bin);
serialized_bin += desc.param;
m_refhold.emplace_back(AlgoBase::deserialize(serialized_bin));
m_all_algos_map.emplace(desc, m_refhold.back().get());
return m_refhold.back().get();
}

void clear() {
m_all_algos_map.clear();
m_refhold.clear();
}

const typename AlgoBase::Mapper& all_algos_map() const {
return m_all_algos_map;
}
};

} // namespace megdnn

namespace std {
template <>
struct hash<megdnn::detail::Algorithm::Info::Desc> {
std::size_t operator()(
const megdnn::detail::Algorithm::Info::Desc& desc) const {
return megdnn::hash_combine<size_t>(
megdnn::hash_combine<size_t>(
std::hash<std::string>()(desc.param),
std::hash<uint32_t>()(desc.type)),
std::hash<uint32_t>()(static_cast<uint32_t>(desc.handle_type)));
}
};
} // namespace std

// vim: syntax=cpp.doxygen

+ 25
- 6
dnn/src/common/algo_chooser.h View File

@@ -25,15 +25,34 @@ namespace megdnn {
*/ */
template <class Opr, typename... Args> template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
typename Opr::Algorithm* ret;
if (auto set = opr->execution_policy().algorithm) {
typename Opr::AlgorithmInfo ret;
auto set = opr->execution_policy().algo;
if (set.valid()) {
ret = set; ret = set;
} else { } else {
ret = opr->get_algorithm_heuristic(std::forward<Args>(args)...,
std::numeric_limits<size_t>::max(),
false);
ret = opr->get_algorithm_info_heuristic(
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
false);
}
return opr->get_algo_from_desc(ret.desc);
}

/*!
* \brief get user-configured algorithm, or heuristic algorithm. used in opencl
* whose algo need to be constructed each time.
*/
template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) {
typename Opr::AlgorithmInfo ret;
auto set = opr->execution_policy().algo;
if (set.valid()) {
return opr->algo_pack().construct_and_get_algo(set.desc);
} else {
ret = opr->get_algorithm_info_heuristic(
std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
false);
return opr->get_algo_from_desc(ret.desc);
} }
return static_cast<typename Opr::AlgoBase*>(ret);
} }


/*! /*!


+ 33
- 0
dnn/src/common/utils.h View File

@@ -9,6 +9,32 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */


/**
* Boost Software License - Version 1.0 - August 17th, 2003
*
* Permission is hereby granted, free of charge, to any person or organization
* obtaining a copy of the software and accompanying documentation covered by
* this license (the "Software") to use, reproduce, display, distribute,
* execute, and transmit the Software, and to prepare derivative works of the
* Software, and to permit third-parties to whom the Software is furnished to
* do so, all subject to the following:
*
* The copyright notices in the Software and this entire statement, including
* the above license grant, this restriction and the following disclaimer,
* must be included in all copies of the Software, in whole or in part, and
* all derivative works of the Software, unless such copies or derivative
* works are solely in the form of machine-executable object code generated by
* a source language processor.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
* SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
* FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/

#pragma once #pragma once


#include "megdnn/arch.h" #include "megdnn/arch.h"
@@ -263,6 +289,13 @@ constexpr uint32_t operator"" _hash(char const* str, size_t count) {
return XXHash64CT::hash(str, count, 20160701); return XXHash64CT::hash(str, count, 20160701);
} }


// refer to https://www.boost.org/doc/libs/1_64_0/boost/functional/hash/hash.hpp
template <typename T>
inline T hash_combine(T seed, T value) {
seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
return seed;
}

template <typename Vec> template <typename Vec>
std::string vec2str(Vec&& vec) { std::string vec2str(Vec&& vec) {
std::string res; std::string res;


+ 6
- 0
dnn/src/cuda/batch_conv_bias/algo.cpp View File

@@ -18,8 +18,14 @@ using namespace cuda;
BatchConvBiasForwardImpl::AlgoPack::AlgoPack() { BatchConvBiasForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&int8_nchw4_gemm_dotprod); all_algos.push_back(&int8_nchw4_gemm_dotprod);
all_algos.push_back(&int8_nchw4_implicit_gemm_dotprod); all_algos.push_back(&int8_nchw4_implicit_gemm_dotprod);

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchConvBiasForwardImpl)

BatchConvBiasForwardImpl::AlgoPack BatchConvBiasForwardImpl::sm_algo_pack; BatchConvBiasForwardImpl::AlgoPack BatchConvBiasForwardImpl::sm_algo_pack;


BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(


+ 17
- 4
dnn/src/cuda/batch_conv_bias/algo.h View File

@@ -11,13 +11,16 @@


#pragma once #pragma once


#include <csetjmp>
#include <unordered_map>
#include "megdnn/oprs.h" #include "megdnn/oprs.h"


#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/cuda/batch_conv_bias/opr_impl.h" #include "src/cuda/batch_conv_bias/opr_impl.h"
#include "src/cuda/handle.h" #include "src/cuda/handle.h"


#include "src/common/algo_base.h"
#include "src/common/metahelper.h"

namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


@@ -26,6 +29,12 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_GEMM_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
BatchConvBiasForwardImpl* opr; BatchConvBiasForwardImpl* opr;
@@ -85,6 +94,7 @@ public:
const char* name() const override { const char* name() const override {
return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD"; return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD";
} }
MEGDNN_DECL_ALGO_TYPE(CUDA_GEMM_NCHW4_DOTPROD_INT8)
}; };


class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp final class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp final
@@ -99,15 +109,16 @@ public:
const char* name() const override { const char* name() const override {
return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD"; return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD";
} }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8)


private: private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const; const SizeArgs& args) const;
}; };


class BatchConvBiasForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
class BatchConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;


public: public:
AlgoPack(); AlgoPack();
@@ -116,6 +127,8 @@ public:
AlgoInt8NCHW4DotProdImplicitGemmPrecomp int8_nchw4_implicit_gemm_dotprod; AlgoInt8NCHW4DotProdImplicitGemmPrecomp int8_nchw4_implicit_gemm_dotprod;


std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;

const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda } // namespace cuda


+ 12
- 9
dnn/src/cuda/batch_conv_bias/opr_impl.h View File

@@ -26,6 +26,18 @@ public:
const TensorLayout& bias, const TensorLayout& bias,
const TensorLayout& z, const TensorLayout& z,
const TensorLayout& dst) override; const TensorLayout& dst) override;
const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoInt8NCHW4DotProdGemm;
class AlgoInt8NCHW4DotProdImplicitGemmPrecomp;

class AlgoPack;

static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms( std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
@@ -37,15 +49,6 @@ public:
const TensorLayout& dst, const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; bool reproducible) override;
const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoInt8NCHW4DotProdGemm;
class AlgoInt8NCHW4DotProdImplicitGemmPrecomp;

class AlgoPack;

static const AlgoPack& algo_pack() { return sm_algo_pack; }


private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;


+ 8
- 0
dnn/src/cuda/batched_matrix_mul/algo.cpp View File

@@ -60,4 +60,12 @@ BatchedMatrixMulForwardImpl::AlgoPack::AlgoPack() {
for (auto& algo : brute_force_algos) { for (auto& algo : brute_force_algos) {
all_algos.push_back(&algo); all_algos.push_back(&algo);
} }

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }

MEGDNN_DEF_GET_ALGO_FROM_DESC(BatchedMatrixMulForwardImpl)

// vim: syntax=cpp.doxygen

+ 26
- 3
dnn/src/cuda/batched_matrix_mul/algo.h View File

@@ -16,6 +16,8 @@
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/cuda/batched_matrix_mul/opr_impl.h" #include "src/cuda/batched_matrix_mul/opr_impl.h"
#include "src/cuda/matrix_mul/cublasLt_wrapper.h" #include "src/cuda/matrix_mul/cublasLt_wrapper.h"
#include "src/common/metahelper.h"

#if CUDA_VERSION >= 10010 #if CUDA_VERSION >= 10010
#include <cublasLt.h> #include <cublasLt.h>
#endif #endif
@@ -28,6 +30,14 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_BRUTE_FORCE,
CUDA_CUBLAS,
CUDA_CUBLASLT,
CUDA_INT8X8X32,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
BatchedMatrixMulForwardImpl* opr; BatchedMatrixMulForwardImpl* opr;
@@ -90,6 +100,13 @@ public:
void exec(const ExecArgs& args) const final; void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return m_name.c_str(); } const char* name() const override { return m_name.c_str(); }
MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE)

std::string param() const override {
std::string ret;
serialize_write_pod(m_algorithm, ret);
return ret;
}
}; };
class BatchedMatrixMulForwardImpl::AlgoCublas final class BatchedMatrixMulForwardImpl::AlgoCublas final
: public BatchedMatrixMulForwardImpl::AlgoBase { : public BatchedMatrixMulForwardImpl::AlgoBase {
@@ -100,6 +117,7 @@ public:
void exec(const ExecArgs& args) const final; void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "CUBLAS"; } const char* name() const override { return "CUBLAS"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
}; };
#if CUDA_VERSION >= 10010 #if CUDA_VERSION >= 10010
class BatchedMatrixMulForwardImpl::AlgoCublasLt final : public AlgoBase { class BatchedMatrixMulForwardImpl::AlgoCublasLt final : public AlgoBase {
@@ -110,6 +128,7 @@ public:
void exec(const ExecArgs& args) const final; void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "CUBLAS_LT"; } const char* name() const override { return "CUBLAS_LT"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT)
}; };
#endif #endif
class BatchedMatrixMulForwardImpl::AlgoInt8x8x32 final class BatchedMatrixMulForwardImpl::AlgoInt8x8x32 final
@@ -121,11 +140,13 @@ public:
void exec(const ExecArgs& args) const final; void exec(const ExecArgs& args) const final;
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "INT8x8x32"; } const char* name() const override { return "INT8x8x32"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32)
}; };
class BatchedMatrixMulForwardImpl::AlgoPack {

class BatchedMatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
MatrixMulForwardImpl::AlgoPack mm_pack; MatrixMulForwardImpl::AlgoPack mm_pack;
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;


public: public:
AlgoPack(); AlgoPack();
@@ -137,6 +158,8 @@ public:
AlgoInt8x8x32 int8x8x32; AlgoInt8x8x32 int8x8x32;
std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;
std::vector<AlgoBruteForce> brute_force_algos; std::vector<AlgoBruteForce> brute_force_algos;

const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn

+ 3
- 3
dnn/src/cuda/batched_matrix_mul/brute_force.cpp View File

@@ -24,7 +24,7 @@ bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available(
const SizeArgs& args) const { const SizeArgs& args) const {
MatrixMulForwardImpl mm{args.opr->handle()}; MatrixMulForwardImpl mm{args.opr->handle()};
mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB}; mm.param() = {args.opr->param().transposeA, args.opr->param().transposeB};
mm.execution_policy() = {m_algorithm};
mm.execution_policy() = {m_algorithm->info()};


auto mm_layout_a = args.layout_a.remove_axis(0); auto mm_layout_a = args.layout_a.remove_axis(0);
auto mm_layout_b = args.layout_b.remove_axis(0); auto mm_layout_b = args.layout_b.remove_axis(0);
@@ -39,7 +39,7 @@ size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes(
auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); auto mm_opr = args.opr->handle()->create_operator<MatrixMulForward>();
mm_opr->param() = {args.opr->param().transposeA, mm_opr->param() = {args.opr->param().transposeA,
args.opr->param().transposeB}; args.opr->param().transposeB};
mm_opr->execution_policy() = {m_algorithm};
mm_opr->execution_policy() = {m_algorithm->info()};


return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b, return mm_opr->get_workspace_in_bytes(args.layout_a, args.layout_b,
args.layout_c); args.layout_c);
@@ -50,7 +50,7 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec(
auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>(); auto&& mm_opr = args.opr->handle()->create_operator<MatrixMulForward>();
mm_opr->param() = {args.opr->param().transposeA, mm_opr->param() = {args.opr->param().transposeA,
args.opr->param().transposeB}; args.opr->param().transposeB};
mm_opr->execution_policy() = {m_algorithm};
mm_opr->execution_policy() = {m_algorithm->info()};
rep(n, N) { rep(n, N) {
TensorND A_, B_, C_; TensorND A_, B_, C_;
auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) {


+ 10
- 6
dnn/src/cuda/batched_matrix_mul/opr_impl.h View File

@@ -32,6 +32,16 @@ public:
_megdnn_workspace workspace) override; _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B, size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C) override; const TensorLayout& C) override;

const char* get_algorithm_set_name() const override {
return "BATCHED_MATMUL";
}

bool is_thread_safe() const override { return true; }
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A, std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
const TensorLayout& B, const TensorLayout& B,
const TensorLayout& C) override; const TensorLayout& C) override;
@@ -40,12 +50,6 @@ public:
const TensorLayout& C, const TensorLayout& C,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; bool reproducible) override;
const char* get_algorithm_set_name() const override {
return "BATCHED_MATMUL";
}

bool is_thread_safe() const override { return true; }
static const AlgoPack& algo_pack() { return sm_algo_pack; }


private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;


+ 10
- 37
dnn/src/cuda/conv_bias/algo.cpp View File

@@ -100,10 +100,16 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
for (size_t i = all_algo_size; i < all_algos.size(); ++i) { for (size_t i = all_algo_size; i < all_algos.size(); ++i) {
non_cudnn_algos.push_back(all_algos[i]); non_cudnn_algos.push_back(all_algos[i]);
} }

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack; ConvBiasForwardImpl::AlgoPack ConvBiasForwardImpl::sm_algo_pack;


MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl)

ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs(
ConvBiasForwardImpl* o, const TensorLayout& src, ConvBiasForwardImpl* o, const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& filter, const TensorLayout& bias,
@@ -172,43 +178,10 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const {
} }


void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() { void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)

#define DEF_ALGO(NAME, REPROD) \
cudnn_conv_bias_activations.push_back( \
{REPROD, \
"CUDNN:ConvBiasActivation:" #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL), \
NAME}); \
cudnn_convs.push_back( \
{REPROD, \
"CUDNN:Convolution:" #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL), \
NAME})

DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true);
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true);
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true);
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true);
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true);
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true);

#if CUDNN_MAJOR >= 5
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true);
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true);
#endif
#endif

#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif

#undef DEF_ALGO

#undef V
#undef V1
for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
cudnn_conv_bias_activations.push_back(algo.first);
cudnn_convs.push_back(algo.first);
}
} }


#if CUDA_VERSION >= 10000 #if CUDA_VERSION >= 10000


+ 160
- 31
dnn/src/cuda/conv_bias/algo.h View File

@@ -6,19 +6,23 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#pragma once #pragma once


#include "megdnn/oprs.h" #include "megdnn/oprs.h"


#include "src/common/algo_base.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/common/metahelper.h"
#include "src/cuda/conv_bias/conv_bias_int8.cuh" #include "src/cuda/conv_bias/conv_bias_int8.cuh"
#include "src/cuda/conv_bias/helper.h" #include "src/cuda/conv_bias/helper.h"
#include "src/cuda/conv_bias/opr_impl.h" #include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#include "src/cuda/cudnn_wrapper.h"


#include <cuda.h> #include <cuda.h>
#include <memory> #include <memory>
@@ -38,11 +42,39 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_CUDNN_CONVBIAS,
CUDA_CHANWISE,
CUDA_CHANWISE_SMALL,
CUDA_CHANWISE_INT8X8X32,
CUDA_CUDNN_CONV,
CUDA_INPLACE_MATMUL,
CUDA_MATMUL,
CUDA_MATMUL_INT8X8X32,
CUDA_1X1,
CUDA_BATCHED_MATMUL,
CUDA_GROUP_CONV_GENERAL,
CUDA_WMMA_UINT4X4X32,
CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8,
CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8,
CUDA_BFLOAT16,
CUDA_IMPLICIT_GEMM_SASS_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW4_DOTPROD_INT8,
CUDA_IMPLICIT_GEMM_SASS_NCHW32_IMMA_INT8,
CUDA_IMPLICIT_GEMM_1X1_SASS_NCHW32_IMMA_INT8,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs : public conv_bias::BiasForwardSizeArgs { struct SizeArgs : public conv_bias::BiasForwardSizeArgs {
ConvBiasForwardImpl* opr; ConvBiasForwardImpl* opr;
const PreprocessedFilter* preprocessed_filter; const PreprocessedFilter* preprocessed_filter;
std::string to_string() const; std::string to_string() const;
SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src, SizeArgs(ConvBiasForwardImpl* opr, const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& filter, const TensorLayout& bias,
@@ -80,13 +112,17 @@ public:
virtual void exec(const ExecArgs& args) const = 0; virtual void exec(const ExecArgs& args) const = 0;
virtual size_t get_preprocess_workspace_in_bytes( virtual size_t get_preprocess_workspace_in_bytes(
const SizeArgs& args) const { const SizeArgs& args) const {
MEGDNN_MARK_USED_VAR(args);
return 0; return 0;
} }
virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout( virtual SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const SizeArgs& args) const { const SizeArgs& args) const {
MEGDNN_MARK_USED_VAR(args);
return {}; return {};
} }
virtual void exec_preprocess(const ExecArgs& args) const {}
virtual void exec_preprocess(const ExecArgs& args) const {
MEGDNN_MARK_USED_VAR(args);
}


bool is_available_wk(const SizeArgs& args, size_t limit) { bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
@@ -114,11 +150,14 @@ public:


class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase { class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase {
public: public:
AlgoCUDNNConvBiasActivation(bool is_reproducible, const char* name,
cudnnConvolutionFwdAlgo_t cudnn_enum)
: m_is_reproducible(is_reproducible),
m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})),
m_cudnn_enum(cudnn_enum) {}
AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_fwd_algos().end());
m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
m_name = ConvBiasForward::algo_name<DefaultParam>(
"CUDNN:ConvBiasActivation:" + m_attr.name, {});
}


size_t get_workspace_in_bytes(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
@@ -127,16 +166,24 @@ public:


const char* name() const override { return m_name.c_str(); } const char* name() const override { return m_name.c_str(); }


bool is_reproducible() const override { return m_is_reproducible; }
bool is_reproducible() const override { return m_attr.is_reproducible; }


cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; } cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; }


bool is_cudnn() const override { return true; } bool is_cudnn() const override { return true; }


MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS)

std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}

private: private:
bool m_is_reproducible;
std::string m_name; std::string m_name;
cudnnConvolutionFwdAlgo_t m_cudnn_enum; cudnnConvolutionFwdAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;
}; };


class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase { class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase {
@@ -154,6 +201,8 @@ public:
} }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }


MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)

private: private:
mutable std::string m_name; mutable std::string m_name;
}; };
@@ -172,6 +221,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)


private: private:
mutable std::string m_name; mutable std::string m_name;
@@ -190,6 +240,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32)


private: private:
mutable std::string m_name; mutable std::string m_name;
@@ -197,27 +248,39 @@ private:


class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase { class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase {
public: public:
AlgoCUDNNConv(bool is_reproducible, const char* name,
cudnnConvolutionFwdAlgo_t cudnn_enum)
: m_is_reproducible(is_reproducible),
m_name(ConvBiasForward::algo_name<DefaultParam>(name, {})),
m_cudnn_enum(cudnn_enum) {}
AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_fwd_algos().end());
m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum);
m_name = ConvBiasForward::algo_name<DefaultParam>(
"CUDNN:Convolution:" + m_attr.name, {});
}


bool is_available(const SizeArgs& args) const override; bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;


bool is_reproducible() const override { return m_is_reproducible; }
bool is_reproducible() const override { return m_attr.is_reproducible; }


const char* name() const override { return m_name.c_str(); } const char* name() const override { return m_name.c_str(); }


cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; } cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }


bool is_cudnn() const override { return true; } bool is_cudnn() const override { return true; }

MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV)

std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}

private: private:
bool m_is_reproducible;
std::string m_name; std::string m_name;
cudnnConvolutionFwdAlgo_t m_cudnn_enum; cudnnConvolutionFwdAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;


WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
}; };
@@ -237,6 +300,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)


private: private:
mutable std::string m_name; mutable std::string m_name;
@@ -261,6 +325,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)


private: private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
@@ -281,6 +346,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32)


private: private:
bool need_src_unroll(const SizeArgs& args) const; bool need_src_unroll(const SizeArgs& args) const;
@@ -310,6 +376,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_1X1)


private: private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
@@ -333,6 +400,7 @@ public:
return m_name.c_str(); return m_name.c_str();
} }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)


private: private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
@@ -354,6 +422,13 @@ public:


static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& dst_pg, TensorLayout& bias_pg); TensorLayout& dst_pg, TensorLayout& bias_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)

std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}


private: private:
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
@@ -370,10 +445,13 @@ public:
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
const char* name() const override { return "QUINT4x4x32_WMMA"; } const char* name() const override { return "QUINT4x4x32_WMMA"; }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }

private: private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const;
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const;
bool use_kernel_fhxfw(const SizeArgs& args) const; bool use_kernel_fhxfw(const SizeArgs& args) const;
size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const; size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const;
MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32)
}; };
#endif #endif


@@ -395,6 +473,7 @@ public:
const convolution::ConvParam& param, float alpha, float beta, const convolution::ConvParam& param, float alpha, float beta,
float gamma, float scale, cudaStream_t stream, float gamma, float scale, cudaStream_t stream,
param::ConvBias::NonlineMode nonlinear_mode); param::ConvBias::NonlineMode nonlinear_mode);
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8)
}; };


class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final
@@ -415,8 +494,9 @@ public:
warp_k == 32 && stage == 2) { warp_k == 32 && stage == 2) {
return ""; return "";
} }
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
threadblock_k, warp_m, warp_n, warp_k, stage);
return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m,
threadblock_n, threadblock_k, warp_m, warp_n,
warp_k, stage);
} }
}; };
AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
@@ -433,6 +513,13 @@ public:
SmallVector<TensorLayout> deduce_preprocessed_filter_layout( SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const SizeArgs& args) const override; const SizeArgs& args) const override;
void exec_preprocess(const ExecArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override;
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)

std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}


private: private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
@@ -457,9 +544,7 @@ public:
bool is_available(const SizeArgs& args) const override; bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
template <typename BiasVisitor> template <typename BiasVisitor>
static void dispatch_nonlinear_mode( static void dispatch_nonlinear_mode(
@@ -471,6 +556,14 @@ public:
MMATileSize mma_tile_size); MMATileSize mma_tile_size);
static std::string to_string(MMATileSize mma_tile_size); static std::string to_string(MMATileSize mma_tile_size);


MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8)

std::string param() const override {
std::string ret;
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}

private: private:
MMATileSize m_mma_tile_size; MMATileSize m_mma_tile_size;
std::string m_name; std::string m_name;
@@ -488,10 +581,16 @@ public:
bool is_available(const SizeArgs& args) const override; bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8)

std::string param() const override {
std::string ret;
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}

private: private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
const SizeArgs& args) const; const SizeArgs& args) const;
@@ -513,6 +612,13 @@ public:
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); } const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8)

std::string param() const override {
std::string ret;
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}


private: private:
MMATileSize m_mma_tile_size; MMATileSize m_mma_tile_size;
@@ -533,6 +639,13 @@ public:
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
const char* name() const override { return m_name.c_str(); } const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8)

std::string param() const override {
std::string ret;
serialize_write_pod(m_mma_tile_size, ret);
return ret;
}


private: private:
MMATileSize m_mma_tile_size; MMATileSize m_mma_tile_size;
@@ -570,6 +683,13 @@ public:
SmallVector<TensorLayout> deduce_preprocessed_filter_layout( SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const SizeArgs& args) const override; const SizeArgs& args) const override;
void exec_preprocess(const ExecArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override;
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_IMMA_NCHW32_INT8)

std::string param() const override {
std::string ret;
serialize_write_pod(m_algo_param, ret);
return ret;
}


private: private:
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
@@ -592,6 +712,14 @@ public:


bool is_reproducible() const override { return m_impl->is_reproducible(); } bool is_reproducible() const override { return m_impl->is_reproducible(); }


MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)

std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}

private: private:
SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr, SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr,
TensorLayout& fsrc, TensorLayout& ffilter, TensorLayout& fsrc, TensorLayout& ffilter,
@@ -603,17 +731,16 @@ private:
}; };




class ConvBiasForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
class ConvBiasForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;


public: public:
AlgoPack(); AlgoPack();


std::vector<AlgoBase*> all_algos, std::vector<AlgoBase*> all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported //! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos,
bfloat16_algos;
non_cudnn_algos, bfloat16_algos;
std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations; std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations;
std::vector<AlgoCUDNNConv> cudnn_convs; std::vector<AlgoCUDNNConv> cudnn_convs;
AlgoChanwise chanwise; AlgoChanwise chanwise;
@@ -646,6 +773,8 @@ public:


AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo); AlgoBase* cudnn_conv_from_enum(cudnnConvolutionFwdAlgo_t algo);


const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }

private: private:
#if CUDA_VERSION >= 10000 #if CUDA_VERSION >= 10000
void fill_imma_algos(); void fill_imma_algos();


+ 2
- 2
dnn/src/cuda/conv_bias/bfloat16.cpp View File

@@ -47,7 +47,7 @@ ConvBiasForwardImpl::AlgoBFloat16::float_args(
change_dtype(fdst); change_dtype(fdst);
opr->param() = args.opr->param(); opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT; opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_impl};
opr->execution_policy() = {m_impl->info()};
return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst); return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst);
} }


@@ -110,7 +110,7 @@ void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
auto convbias_opr = args.handle->create_operator<ConvBias>(); auto convbias_opr = args.handle->create_operator<ConvBias>();
convbias_opr->param() = args.opr->param(); convbias_opr->param() = args.opr->param();
convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT; convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
convbias_opr->execution_policy() = {m_impl};
convbias_opr->execution_policy() = {m_impl->info()};
convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor,
fdst_tensor, nullptr, cvter.workspace()); fdst_tensor, nullptr, cvter.workspace());
} }


+ 2
- 2
dnn/src/cuda/conv_bias/opr_impl.cpp View File

@@ -63,12 +63,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
auto conv_args = args; auto conv_args = args;


auto cudnn_conv_bias_act_from_enum_wrapper = auto cudnn_conv_bias_act_from_enum_wrapper =
[this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
[](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo); return sm_algo_pack.cudnn_conv_bias_act_from_enum(algo);
}; };


auto cudnn_conv_from_enum_wrapper = auto cudnn_conv_from_enum_wrapper =
[this](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
[](cudnnConvolutionFwdAlgo_t algo) -> AlgoBase* {
return sm_algo_pack.cudnn_conv_from_enum(algo); return sm_algo_pack.cudnn_conv_from_enum(algo);
}; };




+ 14
- 11
dnn/src/cuda/conv_bias/opr_impl.h View File

@@ -24,17 +24,6 @@ public:
_megdnn_tensor_out dst, _megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter, const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) override; _megdnn_workspace workspace) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
@@ -80,6 +69,20 @@ public:


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }


static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& bias,
const TensorLayout& z,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;

private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };


+ 6
- 0
dnn/src/cuda/convolution/backward_data/algo.cpp View File

@@ -52,8 +52,14 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
all_algos.push_back(bfloat16_refhold.back().get()); all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get()); bfloat16_algos.push_back(bfloat16_refhold.back().get());
} }

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl)

ConvolutionBackwardDataImpl::AlgoCUDNN* ConvolutionBackwardDataImpl::AlgoCUDNN*
ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum( ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionBwdDataAlgo_t algo) { cudnnConvolutionBwdDataAlgo_t algo) {


+ 163
- 157
dnn/src/cuda/convolution/backward_data/algo.h View File

@@ -11,8 +11,11 @@


#pragma once #pragma once


#include "src/cuda/convolution/helper.h"
#include <unordered_map> #include <unordered_map>
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/convolution/helper.h"
#include "src/cuda/cudnn_wrapper.h"


namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {
@@ -23,154 +26,146 @@ namespace cuda {
* All the algo impls should try to support non-contiguous batch dim, for group * All the algo impls should try to support non-contiguous batch dim, for group
* conv execution. * conv execution.
*/ */
class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm {
protected:
~AlgoBase() = default;

public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl *handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout, *filter_layout;
ConvolutionBackwardDataImpl *opr;

std::string to_string() const;
void init_desc(convolution::CUDNNBwdDataDescs &desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
}
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad);
SizeArgs(ConvolutionBackwardDataImpl* opr,
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad);

convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_layout, filter_meta,
diff_layout};
}
};
struct ExecArgs: public SizeArgs {
const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;

ExecArgs(ConvolutionBackwardDataImpl *opr,
_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs &args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0;
virtual void exec(const ExecArgs &args) const = 0;

bool is_available_wk(const SizeArgs &args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;


bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}

AlgoBase& check_workspace(
const SizeArgs &args, const Workspace &workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd data algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
public:
enum class AlgoType : uint32_t {
CUDA_CUDNN,
CUDA_MATMUL,
CUDA_CHANWISE,
CUDA_CHANWISE_SMALL,
CUDA_BFLOAT16,
CUDA_GROUP_CONV_GENERAL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl* handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout, *filter_layout;
ConvolutionBackwardDataImpl* opr;

std::string to_string() const;
void init_desc(convolution::CUDNNBwdDataDescs& desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
} }

virtual bool is_cudnn() const {
return false;
SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
const TensorLayout& diff, const TensorLayout& grad);
SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad);

convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_layout, filter_meta,
diff_layout};
} }
};
struct ExecArgs : public SizeArgs {
const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;

ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;

bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}

bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}

AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd data algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}

virtual bool is_cudnn() const { return false; }
}; };


class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase { class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase {
bool m_is_reproducible;
const char *m_name;
cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; cudnnConvolutionBwdDataAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;


public:
public:
AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_bwd_data_algos().end());
m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum);
}


AlgoCUDNN(bool is_reproducible, const char *name,
cudnnConvolutionBwdDataAlgo_t cudnn_enum):
m_is_reproducible(is_reproducible),
m_name(name),
m_cudnn_enum(cudnn_enum)
{}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }


bool is_reproducible() const override {
return m_is_reproducible;
}
const char* name() const override { return m_attr.name.c_str(); }


const char* name() const override {
return m_name;
}
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }


cudnnConvolutionBwdDataAlgo_t cudnn_enum() const {
return m_cudnn_enum;
}
bool is_cudnn() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)


bool is_cudnn() const override {
return true;
}
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
}; };


//! im2col and matmul, with dilation //! im2col and matmul, with dilation
class ConvolutionBackwardDataImpl::AlgoMatmul final: public AlgoBase {
template<typename T>
static void exec_internal(const ExecArgs &args);
class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase {
template <typename T>
static void exec_internal(const ExecArgs& args);


public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return "MATMUL";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
}; };


class ConvolutionBackwardDataImpl::AlgoChanwise final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return "CHANNEL_WISE";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
}; };


class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return "CHANNEL_WISE_SMALL";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE_SMALL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
}; };


class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
@@ -190,61 +185,72 @@ private:
TensorLayout& fsrc, TensorLayout& ffilter, TensorLayout& fsrc, TensorLayout& ffilter,
TensorLayout& fdst) const; TensorLayout& fdst) const;
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)

std::string param() const override {
std::string ret;
serialize_write_pod(m_algorithm, ret);
return ret;
}
}; };


//! implement group conv by another algo //! implement group conv by another algo
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final
: public AlgoBase {
AlgoBase* m_impl;
std::string m_name; std::string m_name;


public:
AlgoGroupConvGeneral(AlgoBase *impl);
public:
AlgoGroupConvGeneral(AlgoBase* impl);


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }


bool is_reproducible() const override {
return m_impl->is_reproducible();
}
bool is_reproducible() const override { return m_impl->is_reproducible(); }

static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg,
TensorLayout& grad_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)


static void modify_size_args(SizeArgs &args,
TensorLayout &diff_pg, TensorLayout &grad_pg);
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
}; };


class ConvolutionBackwardDataImpl::AlgoPack {
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp // defined in cudnn.cpp
void fill_cudnn_algos(); void fill_cudnn_algos();


AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator = (const AlgoPack &) = delete;
AlgoBase::Mapper m_all_algos_map;


public:
AlgoPack();
public:
AlgoPack();


std::vector<AlgoCUDNN> cudnn;
AlgoMatmul matmul;
AlgoChanwise chanwise;
AlgoChanwiseSmall chanwise_small;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::vector<AlgoCUDNN> cudnn;
AlgoMatmul matmul;
AlgoChanwise chanwise;
AlgoChanwiseSmall chanwise_small;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;


std::vector<AlgoBase*>
std::vector<AlgoBase*>
//! all algorithms //! all algorithms
all_algos, all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported //! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos,
bfloat16_algos;
non_cudnn_algos, bfloat16_algos;

AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);


AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 2
- 2
dnn/src/cuda/convolution/backward_data/bfloat16.cpp View File

@@ -42,7 +42,7 @@ ConvolutionBackwardDataImpl::AlgoBFloat16::float_args(
change_dtype(fgrad); change_dtype(fgrad);
opr->param() = args.opr->param(); opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT; opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_algorithm};
opr->execution_policy() = {m_algorithm->info()};
return SizeArgs(opr, ffilter, fdiff, fgrad); return SizeArgs(opr, ffilter, fdiff, fgrad);
} }


@@ -105,7 +105,7 @@ void ConvolutionBackwardDataImpl::AlgoBFloat16::exec(
args.handle->create_operator<ConvolutionBackwardData>(); args.handle->create_operator<ConvolutionBackwardData>();
conv_back_data_opr->param() = args.opr->param(); conv_back_data_opr->param() = args.opr->param();
conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT; conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
conv_back_data_opr->execution_policy() = {m_algorithm};
conv_back_data_opr->execution_policy() = {m_algorithm->info()};
conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor, conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor,
cvter.workspace()); cvter.workspace());
} }


+ 3
- 29
dnn/src/cuda/convolution/backward_data/cudnn.cpp View File

@@ -98,35 +98,9 @@ void ConvolutionBackwardDataImpl::AlgoCUDNN::exec(
} }


void ConvolutionBackwardDataImpl::AlgoPack::fill_cudnn_algos() { void ConvolutionBackwardDataImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)

#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({ \
REPROD, #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \
"." V(CUDNN_PATCHLEVEL), \
NAME})

DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true);
#if CUDNN_MAJOR >= 5
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true);
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true);
#endif
#endif

#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif

#undef DEF_ALGO

#undef V
#undef V1
for (auto&& algo : CudnnAlgoPack::conv_bwd_data_algos()) {
cudnn.push_back(algo.first);
}
} }


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 6
- 0
dnn/src/cuda/convolution/backward_filter/algo.cpp View File

@@ -49,8 +49,14 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() {
all_algos.push_back(bfloat16_refhold.back().get()); all_algos.push_back(bfloat16_refhold.back().get());
bfloat16_algos.push_back(bfloat16_refhold.back().get()); bfloat16_algos.push_back(bfloat16_refhold.back().get());
} }

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardFilterImpl)

ConvolutionBackwardFilterImpl::AlgoCUDNN* ConvolutionBackwardFilterImpl::AlgoCUDNN*
ConvolutionBackwardFilterImpl::AlgoPack::cudnn_from_enum( ConvolutionBackwardFilterImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionBwdFilterAlgo_t algo) { cudnnConvolutionBwdFilterAlgo_t algo) {


+ 155
- 147
dnn/src/cuda/convolution/backward_filter/algo.h View File

@@ -6,13 +6,16 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#pragma once #pragma once


#include "src/cuda/convolution/helper.h"
#include <unordered_map> #include <unordered_map>
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/convolution/helper.h"


namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {
@@ -23,141 +26,134 @@ namespace cuda {
* All the algo impls should try to support non-contiguous batch dim, for group * All the algo impls should try to support non-contiguous batch dim, for group
* conv execution. * conv execution.
*/ */
class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
protected:
~AlgoBase() = default;

public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl *handle;
const TensorLayout *src_layout, *diff_layout, *grad_layout;
CanonizedFilterMeta grad_filter_meta;
ConvolutionBackwardFilterImpl *opr;

std::string to_string() const;
void init_desc(convolution::CUDNNBwdFilterDescs &desc) const {
desc.set(*src_layout, *diff_layout, grad_filter_meta,
opr->param());
}
SizeArgs(ConvolutionBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(ConvolutionBackwardFilterImpl* opr,
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta);

convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_layout, grad_filter_meta,
diff_layout};
}
};
struct ExecArgs: public SizeArgs {
const TensorND *src_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;

ExecArgs(ConvolutionBackwardFilterImpl *opr,
_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs &args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0;
virtual void exec(const ExecArgs &args) const = 0;

bool is_available_wk(const SizeArgs &args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
class ConvolutionBackwardFilterImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;


bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
public:
enum class AlgoType : uint32_t {
CUDA_CUDNN,
CUDA_MATMUL,
CUDA_CHANWISE,
CUDA_BFLOAT16,
CUDA_GROUP_CONV_GENERAL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl* handle;
const TensorLayout *src_layout, *diff_layout, *grad_layout;
CanonizedFilterMeta grad_filter_meta;
ConvolutionBackwardFilterImpl* opr;

std::string to_string() const;
void init_desc(convolution::CUDNNBwdFilterDescs& desc) const {
desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param());
} }

AlgoBase& check_workspace(
const SizeArgs &args, const Workspace &workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd filter algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}

virtual bool is_cudnn() const {
return false;
SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& grad);
SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta);

convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_layout, grad_filter_meta,
diff_layout};
} }
};
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;

ExecArgs(ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;

bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}

bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}

AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd filter algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}

virtual bool is_cudnn() const { return false; }
}; };


class ConvolutionBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { class ConvolutionBackwardFilterImpl::AlgoCUDNN final : public AlgoBase {
bool m_is_reproducible;
const char *m_name;
cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;


public:
public:
AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv_bwd_flt_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv_bwd_flt_algos().end());
m_attr = CudnnAlgoPack::conv_bwd_flt_algos().at(cudnn_enum);
}


AlgoCUDNN(bool is_reproducible, const char *name,
cudnnConvolutionBwdFilterAlgo_t cudnn_enum):
m_is_reproducible(is_reproducible),
m_name(name),
m_cudnn_enum(cudnn_enum)
{}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }


bool is_reproducible() const override {
return m_is_reproducible;
}
const char* name() const override { return m_attr.name.c_str(); }


const char* name() const override {
return m_name;
}
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; }


cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const {
return m_cudnn_enum;
}
bool is_cudnn() const override { return true; }


bool is_cudnn() const override {
return true;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
}; };


//! im2col and matmul, with dilation //! im2col and matmul, with dilation
class ConvolutionBackwardFilterImpl::AlgoMatmul final: public AlgoBase {
template<typename T>
static void exec_internal(const ExecArgs &args);
class ConvolutionBackwardFilterImpl::AlgoMatmul final : public AlgoBase {
template <typename T>
static void exec_internal(const ExecArgs& args);


public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return "MATMUL";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
}; };


class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return "CHANNEL_WISE";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
}; };


class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase {
@@ -169,6 +165,13 @@ public:


const char* name() const override { return m_name.c_str(); } const char* name() const override { return m_name.c_str(); }
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)

std::string param() const override {
std::string ret;
serialize_write_pod(m_algorithm, ret);
return ret;
}


private: private:
std::string m_name; std::string m_name;
@@ -180,57 +183,62 @@ private:
}; };


//! implement group conv by another algo //! implement group conv by another algo
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final
: public AlgoBase {
AlgoBase* m_impl;
std::string m_name; std::string m_name;


public:
AlgoGroupConvGeneral(AlgoBase *impl);
public:
AlgoGroupConvGeneral(AlgoBase* impl);

bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

const char* name() const override { return m_name.c_str(); }


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_impl->is_reproducible(); }


const char* name() const override {
return m_name.c_str();
}
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& diff_pg);


bool is_reproducible() const override {
return m_impl->is_reproducible();
}
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)


static void modify_size_args(SizeArgs &args,
TensorLayout &src_pg, TensorLayout &diff_pg);
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
}; };


class ConvolutionBackwardFilterImpl::AlgoPack {
class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp // defined in cudnn.cpp
void fill_cudnn_algos(); void fill_cudnn_algos();


AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator = (const AlgoPack &) = delete;
AlgoBase::Mapper m_all_algos_map;


public:
AlgoPack();
public:
AlgoPack();


std::vector<AlgoCUDNN> cudnn;
AlgoMatmul matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;
std::vector<AlgoCUDNN> cudnn;
AlgoMatmul matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold;


std::vector<AlgoBase*>
std::vector<AlgoBase*>
//! all algorithms //! all algorithms
all_algos, all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported //! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos,
bfloat16_algos;
non_cudnn_algos, bfloat16_algos;

AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);


AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 2
- 2
dnn/src/cuda/convolution/backward_filter/bfloat16.cpp View File

@@ -42,7 +42,7 @@ ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args(
change_dtype(fgrad); change_dtype(fgrad);
opr->param() = args.opr->param(); opr->param() = args.opr->param();
opr->param().compute_mode = Param::ComputeMode::DEFAULT; opr->param().compute_mode = Param::ComputeMode::DEFAULT;
opr->execution_policy() = {m_algorithm};
opr->execution_policy() = {m_algorithm->info()};
return SizeArgs(opr, fsrc, fdiff, fgrad); return SizeArgs(opr, fsrc, fdiff, fgrad);
} }


@@ -107,7 +107,7 @@ void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec(
conv_back_filter_opr->param() = args.opr->param(); conv_back_filter_opr->param() = args.opr->param();
conv_back_filter_opr->param().compute_mode = conv_back_filter_opr->param().compute_mode =
Param::ComputeMode::DEFAULT; Param::ComputeMode::DEFAULT;
conv_back_filter_opr->execution_policy() = {m_algorithm};
conv_back_filter_opr->execution_policy() = {m_algorithm->info()};
conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor,
cvter.workspace()); cvter.workspace());
} }


+ 3
- 29
dnn/src/cuda/convolution/backward_filter/cudnn.cpp View File

@@ -80,35 +80,9 @@ void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec(
} }


void ConvolutionBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { void ConvolutionBackwardFilterImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)

#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({ \
REPROD, #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \
"." V(CUDNN_PATCHLEVEL), \
NAME})

DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false);
#if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1)
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, true);
#if CUDNN_MAJOR >= 6
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true);
#endif
#endif

#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif

#undef DEF_ALGO

#undef V
#undef V1
for(auto&& algo : CudnnAlgoPack::conv_bwd_flt_algos()) {
cudnn.push_back(algo.first);
}
} }


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 1
- 10
dnn/src/cuda/convolution/opr_impl.cpp View File

@@ -70,7 +70,7 @@ ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src,
conv_param.dilate_w, conv_param.dilate_w,
0, 0,
conv_param.compute_mode}; conv_param.compute_mode};
ret.convbias_opr->execution_policy() = {this->execution_policy().algorithm};
ret.convbias_opr->execution_policy() = {this->execution_policy().algo};
return ret; return ret;
} }


@@ -183,15 +183,6 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
CUDNNBwdDataDescs desc; CUDNNBwdDataDescs desc;
args.init_desc(desc); args.init_desc(desc);


//disable, segfault in megbrain, need further investigate.
#if 0
bool is_heuristic_success= convolution::
PerformanceModelBackwardData::get_algo_backward_data_success(
args, desc, workspace_limit_in_bytes, &algo);
if (is_heuristic_success) {
return sm_algo_pack.cudnn_from_enum(algo);
}
#endif
#if CUDNN_MAJOR >= 7 #if CUDNN_MAJOR >= 7
int max_count = 0; int max_count = 0;
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(


+ 114
- 92
dnn/src/cuda/convolution/opr_impl.h View File

@@ -24,14 +24,6 @@ class ConvolutionForwardImpl: public ConvolutionForward {
const PreprocessedFilter* preprocessed_filter, const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) override; _megdnn_workspace workspace) override;


std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &filter,
const TensorLayout &dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
size_t get_workspace_in_bytes( size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst, const TensorLayout& dst,
@@ -60,99 +52,129 @@ class ConvolutionForwardImpl: public ConvolutionForward {
TensorLayout bias_layout; TensorLayout bias_layout;
TensorLayout z_layout; TensorLayout z_layout;
}; };
private:
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&,
const TensorLayout&,
const TensorLayout&);
};


class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
public:
using ConvolutionBackwardData::ConvolutionBackwardData;
void exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible) override; bool reproducible) override;
Algorithm* get_algorithm_heuristic(
const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) override;
const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoCUDNN;
class AlgoMatmul;
class AlgoChanwise;
class AlgoChanwiseSmall;
class AlgoGroupConvGeneral;
class AlgoBFloat16;

class AlgoPack;

static const AlgoPack& algo_pack() {
return sm_algo_pack;
}


private: private:
static AlgoPack sm_algo_pack;
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&,
const TensorLayout&,
const TensorLayout&);
}; };


class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter {
public:
using ConvolutionBackwardFilter::ConvolutionBackwardFilter;
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& gradk,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) override;
const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoCUDNN;
class AlgoMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;
class AlgoBFloat16;

class AlgoPack;

static const AlgoPack& algo_pack() {
return sm_algo_pack;
}
class ConvolutionBackwardDataImpl : public ConvolutionBackwardData {
public:
using ConvolutionBackwardData::ConvolutionBackwardData;
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
AlgorithmInfo get_algorithm_info_heuristic(
const TensorLayout& filter, const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff, const TensorLayout& grad,
size_t workspace_limit_in_bytes, bool reproducible) {
return get_algorithm_heuristic(filter, filter_meta, diff, grad,
workspace_limit_in_bytes, reproducible)
->info();
}
size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) override;
const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoCUDNN;
class AlgoMatmul;
class AlgoChanwise;
class AlgoChanwiseSmall;
class AlgoGroupConvGeneral;
class AlgoBFloat16;

class AlgoPack;

static const AlgoPack& algo_pack() { return sm_algo_pack; }

static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const CanonizedFilterMeta& filter_meta,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible);

static AlgoPack sm_algo_pack;
};


private:
static AlgoPack sm_algo_pack;
class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter {
public:
using ConvolutionBackwardFilter::ConvolutionBackwardFilter;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) override;
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
bool reproducible) {
return get_algorithm_heuristic(src, diff, grad, grad_meta,
workspace_limit_in_bytes, reproducible)
->info();
}

const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoCUDNN;
class AlgoMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;
class AlgoBFloat16;

class AlgoPack;

static const AlgoPack& algo_pack() { return sm_algo_pack; }

static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
const CanonizedFilterMeta& grad_meta,
size_t workspace_limit_in_bytes,
bool reproducible);

static AlgoPack sm_algo_pack;
}; };


} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 7
- 1
dnn/src/cuda/convolution3d/backward_data/algo.cpp View File

@@ -39,8 +39,14 @@ Convolution3DBackwardDataImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&i); all_algos.push_back(&i);
} }
megdnn_assert(all_algos_data == all_algos.data()); megdnn_assert(all_algos_data == all_algos.data());

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardDataImpl)

Convolution3DBackwardDataImpl::AlgoCUDNN* Convolution3DBackwardDataImpl::AlgoCUDNN*
Convolution3DBackwardDataImpl::AlgoPack::cudnn_from_enum( Convolution3DBackwardDataImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionBwdDataAlgo_t algo) { cudnnConvolutionBwdDataAlgo_t algo) {
@@ -96,7 +102,7 @@ std::string Convolution3DBackwardDataImpl::AlgoBase::SizeArgs::to_string() const
fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2],
diff_layout->to_string().c_str(), diff_layout->to_string().c_str(),
grad_layout->to_string().c_str(), grad_layout->to_string().c_str(),
fm.padding[0], fm.padding[1], fm.padding[2],
fm.padding[0], fm.padding[1], fm.padding[2],
fm.stride[0], fm.stride[1], fm.stride[2], fm.stride[0], fm.stride[1], fm.stride[2],
fm.dilation[0], fm.dilation[1] ,fm.dilation[2], fm.dilation[0], fm.dilation[1] ,fm.dilation[2],
!fm.should_flip, !fm.should_flip,


+ 134
- 127
dnn/src/cuda/convolution3d/backward_data/algo.h View File

@@ -6,13 +6,16 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#pragma once #pragma once


#include "src/cuda/convolution3d/helper.h"
#include <unordered_map> #include <unordered_map>
#include "src/cuda/convolution3d/helper.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"


namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {
@@ -23,170 +26,174 @@ namespace cuda {
* All the algo impls should try to support non-contiguous batch dim, for group * All the algo impls should try to support non-contiguous batch dim, for group
* conv execution. * conv execution.
*/ */
class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm {
protected:
~AlgoBase() = default;

public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl *handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout;
Convolution3DBackwardDataImpl *opr;

std::string to_string() const;
void init_desc(convolution3d::CUDNNBwdDataDescs &desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
}
SizeArgs(Convolution3DBackwardDataImpl *opr,
const TensorLayout &filter, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(Convolution3DBackwardDataImpl *opr,
const CanonizedFilterMeta &filter, const TensorLayout &diff,
const TensorLayout &grad);

convolution3d::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_meta, diff_layout,
opr->param().data_type};
}
};
struct ExecArgs: public SizeArgs {
const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;

ExecArgs(Convolution3DBackwardDataImpl *opr,
_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs &args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0;
virtual void exec(const ExecArgs &args) const = 0;

bool is_available_wk(const SizeArgs &args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(
const SizeArgs &args, const Workspace &workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd data algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
class Convolution3DBackwardDataImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;

public:
enum class AlgoType : uint32_t {
CUDA_GROUP_CONV_GENERAL,
CUDA_CUDNN,
CUDA_CHANWISE,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;


AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl* handle;
CanonizedFilterMeta filter_meta;
const TensorLayout *diff_layout, *grad_layout;
Convolution3DBackwardDataImpl* opr;

std::string to_string() const;
void init_desc(convolution3d::CUDNNBwdDataDescs& desc) const {
desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
} }

virtual bool is_cudnn() const {
return false;
SizeArgs(Convolution3DBackwardDataImpl* opr, const TensorLayout& filter,
const TensorLayout& diff, const TensorLayout& grad);
SizeArgs(Convolution3DBackwardDataImpl* opr,
const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad);

convolution3d::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_meta, diff_layout,
opr->param().data_type};
} }
};
struct ExecArgs : public SizeArgs {
const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;

ExecArgs(Convolution3DBackwardDataImpl* opr, _megdnn_tensor_in filter,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;

bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd data algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}

virtual bool is_cudnn() const { return false; }
}; };


class Convolution3DBackwardDataImpl::AlgoCUDNN final : public AlgoBase { class Convolution3DBackwardDataImpl::AlgoCUDNN final : public AlgoBase {
bool m_is_reproducible;
const char *m_name;
cudnnConvolutionBwdDataAlgo_t m_cudnn_enum; cudnnConvolutionBwdDataAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;


public:
public:
AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv3d_bwd_data_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv3d_bwd_data_algos().end());
m_attr = CudnnAlgoPack::conv3d_bwd_data_algos().at(cudnn_enum);
}


AlgoCUDNN(bool is_reproducible, const char *name,
cudnnConvolutionBwdDataAlgo_t cudnn_enum):
m_is_reproducible(is_reproducible),
m_name(name),
m_cudnn_enum(cudnn_enum)
{}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }


bool is_reproducible() const override {
return m_is_reproducible;
}
const char* name() const override { return m_attr.name.c_str(); }


const char* name() const override {
return m_name;
}
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }


cudnnConvolutionBwdDataAlgo_t cudnn_enum() const {
return m_cudnn_enum;
}
bool is_cudnn() const override { return true; }


bool is_cudnn() const override {
return true;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)

std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
}; };


class Convolution3DBackwardDataImpl::AlgoChanwise final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class Convolution3DBackwardDataImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return "CHANNEL_WISE";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
}; };


//! implement group conv by another algo //! implement group conv by another algo
class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final
: public AlgoBase {
AlgoBase* m_impl;
std::string m_name; std::string m_name;


public:
AlgoGroupConvGeneral(AlgoBase *impl);
public:
AlgoGroupConvGeneral(AlgoBase* impl);


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }


bool is_reproducible() const override {
return m_impl->is_reproducible();
}
bool is_reproducible() const override { return m_impl->is_reproducible(); }

static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg,
TensorLayout& grad_pg);


static void modify_size_args(SizeArgs &args,
TensorLayout &diff_pg, TensorLayout &grad_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
}; };


class Convolution3DBackwardDataImpl::AlgoPack {

class Convolution3DBackwardDataImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp // defined in cudnn.cpp
void fill_cudnn_algos(); void fill_cudnn_algos();


AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator = (const AlgoPack &) = delete;
AlgoBase::Mapper m_all_algos_map;


public:
AlgoPack();
public:
AlgoPack();


std::vector<AlgoCUDNN> cudnn;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<AlgoCUDNN> cudnn;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;


std::vector<AlgoBase*>
std::vector<AlgoBase*>
//! all algorithms //! all algorithms
all_algos, all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported //! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos; non_cudnn_algos;


AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);

const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 3
- 21
dnn/src/cuda/convolution3d/backward_data/cudnn.cpp View File

@@ -80,27 +80,9 @@ void Convolution3DBackwardDataImpl::AlgoCUDNN::exec(
} }


void Convolution3DBackwardDataImpl::AlgoPack::fill_cudnn_algos() { void Convolution3DBackwardDataImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)

#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({ \
REPROD, #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \
"." V(CUDNN_PATCHLEVEL), \
NAME})

DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true);
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif

#undef DEF_ALGO

#undef V
#undef V1
for (auto&& algo : CudnnAlgoPack::conv3d_bwd_data_algos()) {
cudnn.push_back(algo.first);
}
} }


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 9
- 3
dnn/src/cuda/convolution3d/backward_filter/algo.cpp View File

@@ -17,7 +17,7 @@ using namespace cuda;


Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&chanwise); non_cudnn_algos.push_back(&chanwise);
non_cudnn_algos.push_back(&inplace_matmul);
non_cudnn_algos.push_back(&inplace_matmul);
all_algos.push_back(&chanwise); // prefer chanwise all_algos.push_back(&chanwise); // prefer chanwise


fill_cudnn_algos(); fill_cudnn_algos();
@@ -41,8 +41,14 @@ Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() {
} }
megdnn_assert(all_algos_data == all_algos.data()); megdnn_assert(all_algos_data == all_algos.data());
non_cudnn_algos.push_back(all_algos.rbegin()[0]); //group inplace_matmul non_cudnn_algos.push_back(all_algos.rbegin()[0]); //group inplace_matmul

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardFilterImpl)

Convolution3DBackwardFilterImpl::AlgoCUDNN* Convolution3DBackwardFilterImpl::AlgoCUDNN*
Convolution3DBackwardFilterImpl::AlgoPack::cudnn_from_enum( Convolution3DBackwardFilterImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionBwdFilterAlgo_t algo) { cudnnConvolutionBwdFilterAlgo_t algo) {
@@ -99,9 +105,9 @@ Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const {
"pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s",
src_layout->to_string().c_str(), src_layout->to_string().c_str(),
diff_layout->to_string().c_str(), diff_layout->to_string().c_str(),
fm.group, fm.ocpg, fm.icpg,
fm.group, fm.ocpg, fm.icpg,
fm.spatial[0], fm.spatial[1], fm.spatial[2], fm.spatial[0], fm.spatial[1], fm.spatial[2],
fm.padding[0], fm.padding[1], fm.padding[2],
fm.padding[0], fm.padding[1], fm.padding[2],
fm.stride[0], fm.stride[1], fm.stride[2], fm.stride[0], fm.stride[1], fm.stride[2],
fm.dilation[0], fm.dilation[1], fm.dilation[2], fm.dilation[0], fm.dilation[1], fm.dilation[2],
!fm.should_flip, !fm.should_flip,


+ 140
- 140
dnn/src/cuda/convolution3d/backward_filter/algo.h View File

@@ -6,198 +6,198 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#pragma once #pragma once


#include "src/cuda/convolution3d/helper.h"
#include <unordered_map> #include <unordered_map>
#include "src/cuda/convolution3d/helper.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"


namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


class Convolution3DBackwardFilterImpl::AlgoBase: public Algorithm {
protected:
~AlgoBase() = default;

public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs {
HandleImpl *handle;
const TensorLayout *src_layout, *diff_layout;
CanonizedFilterMeta grad_filter_meta;
Convolution3DBackwardFilterImpl *opr;

std::string to_string() const;
void init_desc(convolution3d::CUDNNBwdFilterDescs &desc) const {
desc.set(*src_layout, *diff_layout, grad_filter_meta,
opr->param());
}
SizeArgs(Convolution3DBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const TensorLayout &grad);
SizeArgs(Convolution3DBackwardFilterImpl *opr,
const TensorLayout &src, const TensorLayout &diff,
const CanonizedFilterMeta &grad);

convolution3d::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_filter_meta, diff_layout,
opr->param().data_type};
}
};
struct ExecArgs: public SizeArgs {
const TensorND *src_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;

ExecArgs(Convolution3DBackwardFilterImpl *opr,
_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs &args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0;
virtual void exec(const ExecArgs &args) const = 0;

bool is_available_wk(const SizeArgs &args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd filter algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
class Convolution3DBackwardFilterImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;

public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
enum class AlgoType : uint32_t {
CUDA_GROUP_CONV_GENERAL,
CUDA_CUDNN,
CUDA_INPLACE_MATMUL,
CUDA_CHANWISE,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

struct SizeArgs {
HandleImpl* handle;
const TensorLayout *src_layout, *diff_layout;
CanonizedFilterMeta grad_filter_meta;
Convolution3DBackwardFilterImpl* opr;

std::string to_string() const;
void init_desc(convolution3d::CUDNNBwdFilterDescs& desc) const {
desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param());
} }
SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& grad);
SizeArgs(Convolution3DBackwardFilterImpl* opr, const TensorLayout& src,
const TensorLayout& diff, const CanonizedFilterMeta& grad);


virtual bool is_cudnn() const {
return false;
convolution3d::ForwardSizeArgs as_fwd_args() const {
return {handle, src_layout, grad_filter_meta, diff_layout,
opr->param().data_type};
} }
};
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *diff_tensor, *grad_tensor;
Workspace workspace;

ExecArgs(Convolution3DBackwardFilterImpl* opr, _megdnn_tensor_in src,
_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;

bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv bwd filter algo %s: "
"required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}

virtual bool is_cudnn() const { return false; }
}; };


class Convolution3DBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { class Convolution3DBackwardFilterImpl::AlgoCUDNN final : public AlgoBase {
bool m_is_reproducible;
const char *m_name;
cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum; cudnnConvolutionBwdFilterAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;


public:
public:
AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum)
: m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv3d_bwd_flt_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv3d_bwd_flt_algos().end());
m_attr = CudnnAlgoPack::conv3d_bwd_flt_algos().at(cudnn_enum);
}


AlgoCUDNN(bool is_reproducible, const char *name,
cudnnConvolutionBwdFilterAlgo_t cudnn_enum):
m_is_reproducible(is_reproducible),
m_name(name),
m_cudnn_enum(cudnn_enum)
{}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }


bool is_reproducible() const override {
return m_is_reproducible;
}
const char* name() const override { return m_attr.name.c_str(); }


const char* name() const override {
return m_name;
}
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; }


cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const {
return m_cudnn_enum;
}
bool is_cudnn() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)


bool is_cudnn() const override {
return true;
}
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}
}; };


class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final
: public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;

const char* name() const override {
return "INPLACE_MATMUL";
}
bool is_reproducible() const override {
return false;
}
const char* name() const override { return "INPLACE_MATMUL"; }
bool is_reproducible() const override { return false; }
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
}; };


class Convolution3DBackwardFilterImpl::AlgoChanwise final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
class Convolution3DBackwardFilterImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return "CHANNEL_WISE";
}
bool is_reproducible() const override {
return true;
}
const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
}; };


//! implement group conv by another algo //! implement group conv by another algo
class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final
: public AlgoBase {
AlgoBase* m_impl;
std::string m_name; std::string m_name;


public:
AlgoGroupConvGeneral(AlgoBase *impl);
public:
AlgoGroupConvGeneral(AlgoBase* impl);


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }


bool is_reproducible() const override {
return m_impl->is_reproducible();
}
bool is_reproducible() const override { return m_impl->is_reproducible(); }


static void modify_size_args(SizeArgs &args,
TensorLayout &src_pg, TensorLayout &diff_pg);
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& diff_pg);

MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
}; };


class Convolution3DBackwardFilterImpl::AlgoPack {
class Convolution3DBackwardFilterImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp // defined in cudnn.cpp
void fill_cudnn_algos(); void fill_cudnn_algos();


AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator = (const AlgoPack &) = delete;
AlgoBase::Mapper m_all_algos_map;


public:
AlgoPack();
public:
AlgoPack();


std::vector<AlgoCUDNN> cudnn;
AlgoInplaceMatmul inplace_matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<AlgoCUDNN> cudnn;
AlgoInplaceMatmul inplace_matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;


std::vector<AlgoBase*>
std::vector<AlgoBase*>
//! all algorithms //! all algorithms
all_algos, all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported //! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos; non_cudnn_algos;


AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo);

const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 3
- 23
dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp View File

@@ -66,29 +66,9 @@ void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec(
} }


void Convolution3DBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { void Convolution3DBackwardFilterImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)

#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({REPROD, \
#NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V( \
CUDNN_PATCHLEVEL), \
NAME})

DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false);
#pragma message \
"fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc"
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false);

#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif

#undef DEF_ALGO

#undef V
#undef V1
for (auto&& algo : CudnnAlgoPack::conv3d_bwd_flt_algos()) {
cudnn.push_back(algo.first);
}
} }


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 11
- 5
dnn/src/cuda/convolution3d/forward/algo.cpp View File

@@ -21,13 +21,13 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&a1x1x1); non_cudnn_algos.push_back(&a1x1x1);


all_algos.push_back(&chanwise); all_algos.push_back(&chanwise);
fill_cudnn_algos(); fill_cudnn_algos();
for (auto &&i: cudnn) { for (auto &&i: cudnn) {
all_algos.push_back(&i);
all_algos.push_back(&i);
} }
all_algos.push_back(&inplace_matmul); all_algos.push_back(&inplace_matmul);
all_algos.push_back(&a1x1x1);
all_algos.push_back(&a1x1x1);
all_algos.reserve(all_algos.size() * 2); all_algos.reserve(all_algos.size() * 2);


// add gconv algos by AlgoGroupConvGeneral // add gconv algos by AlgoGroupConvGeneral
@@ -42,10 +42,16 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&i); all_algos.push_back(&i);
} }
megdnn_assert(all_algos_data == all_algos.data()); megdnn_assert(all_algos_data == all_algos.data());
non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul
non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group inplace_matmul
non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1x1 non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1x1

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DForwardImpl)

Convolution3DForwardImpl::AlgoCUDNN* Convolution3DForwardImpl::AlgoCUDNN*
Convolution3DForwardImpl::AlgoPack::cudnn_from_enum( Convolution3DForwardImpl::AlgoPack::cudnn_from_enum(
cudnnConvolutionFwdAlgo_t algo) { cudnnConvolutionFwdAlgo_t algo) {
@@ -99,7 +105,7 @@ std::string Convolution3DForwardImpl::AlgoBase::SizeArgs::to_string() const {
"src=%s, filter=%u{%u,%u,%u,%u,%u}, dst=%s, " "src=%s, filter=%u{%u,%u,%u,%u,%u}, dst=%s, "
"pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s", "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, dtype=%s,%s",
src_layout->to_string().c_str(), src_layout->to_string().c_str(),
fm.group, fm.ocpg, fm.icpg,
fm.group, fm.ocpg, fm.icpg,
fm.spatial[0], fm.spatial[1], fm.spatial[2], fm.spatial[0], fm.spatial[1], fm.spatial[2],
dst_layout->to_string().c_str(), dst_layout->to_string().c_str(),
fm.padding[0], fm.padding[1], fm.padding[2], fm.padding[0], fm.padding[1], fm.padding[2],


+ 148
- 151
dnn/src/cuda/convolution3d/forward/algo.h View File

@@ -6,17 +6,20 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#pragma once #pragma once


#include "megdnn/oprs.h" #include "megdnn/oprs.h"


#include "src/common/utils.h"
#include "src/cuda/convolution3d/helper.h" #include "src/cuda/convolution3d/helper.h"
#include "src/cuda/handle.h"
#include "src/cuda/convolution3d/opr_impl.h" #include "src/cuda/convolution3d/opr_impl.h"
#include "src/common/utils.h"
#include "src/cuda/handle.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"


#include <unordered_map> #include <unordered_map>


@@ -29,195 +32,189 @@ namespace cuda {
* All the algo impls should try to support non-contiguous batch dim, for group * All the algo impls should try to support non-contiguous batch dim, for group
* conv execution. * conv execution.
*/ */
class Convolution3DForwardImpl::AlgoBase: public Algorithm {
protected:
~AlgoBase() = default;

public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs: public convolution3d::ForwardSizeArgs {
Convolution3DForwardImpl *opr;

std::string to_string() const;
void init_desc(convolution3d::CUDNNForwardDescs &desc) const {
desc.set(*src_layout, filter_meta, *dst_layout, opr->param());
}
SizeArgs(Convolution3DForwardImpl *opr,
const TensorLayout &src,
const TensorLayout &filter,
const TensorLayout &dst);
SizeArgs(Convolution3DForwardImpl *opr,
const TensorLayout &src,
const CanonizedFilterMeta &filter,
const TensorLayout &dst);
};
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *filter_tensor, *dst_tensor;
Workspace workspace;

ExecArgs(Convolution3DForwardImpl *opr,
_megdnn_tensor_in src,
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs &args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs &args) const = 0;
virtual void exec(const ExecArgs &args) const = 0;

bool is_available_wk(const SizeArgs &args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(req <= workspace.size,
"conv3d fwd algo %s: required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}

virtual bool is_cudnn() const {
return false;
}
class Convolution3DForwardImpl::AlgoBase : public Algorithm {
protected:
~AlgoBase() = default;

public:
enum class AlgoType : uint32_t {
CUDA_1X1X1,
CUDA_GROUP_CONV_GENERAL,
CUDA_CUDNN,
CUDA_INPLACE_MATMUL,
CUDA_CHANWISE,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs : public convolution3d::ForwardSizeArgs {
Convolution3DForwardImpl* opr;

std::string to_string() const;
void init_desc(convolution3d::CUDNNForwardDescs& desc) const {
desc.set(*src_layout, filter_meta, *dst_layout, opr->param());
}
SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& dst);
SizeArgs(Convolution3DForwardImpl* opr, const TensorLayout& src,
const CanonizedFilterMeta& filter, const TensorLayout& dst);
};
struct ExecArgs : public SizeArgs {
const TensorND *src_tensor, *filter_tensor, *dst_tensor;
Workspace workspace;

ExecArgs(Convolution3DForwardImpl* opr, _megdnn_tensor_in src,
_megdnn_tensor_in filter, _megdnn_tensor_out dst,
_megdnn_workspace workspace);
};
virtual bool is_available(const SizeArgs& args) const = 0;
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0;

bool is_available_wk(const SizeArgs& args, size_t limit) {
return is_available(args) && get_workspace_in_bytes(args) <= limit;
}
bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit);
}
AlgoBase& check_workspace(const SizeArgs& args,
const Workspace& workspace) {
auto req = get_workspace_in_bytes(args);
megdnn_assert(
req <= workspace.size,
"conv3d fwd algo %s: required workspace %zu bytes, got %zu",
name(), req, workspace.size);
return *this;
}

virtual bool is_cudnn() const { return false; }
}; };
class Convolution3DForwardImpl::Algo1x1x1 final: public AlgoBase {
static void extract_matmul_layouts(const SizeArgs &args,
TensorLayout &A, TensorLayout &B, TensorLayout &C);
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;

const char* name() const override {
return "1x1x1";
}
bool is_reproducible() const override {
return true;
}
class Convolution3DForwardImpl::Algo1x1x1 final : public AlgoBase {
static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A,
TensorLayout& B, TensorLayout& C);

public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

const char* name() const override { return "1x1x1"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1)
}; };


//! implement group conv by another algo //! implement group conv by another algo
class Convolution3DForwardImpl::AlgoGroupConvGeneral final: public AlgoBase {
AlgoBase *m_impl;
class Convolution3DForwardImpl::AlgoGroupConvGeneral final : public AlgoBase {
AlgoBase* m_impl;
std::string m_name; std::string m_name;


public:
AlgoGroupConvGeneral(AlgoBase *impl);
public:
AlgoGroupConvGeneral(AlgoBase* impl);


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


const char* name() const override {
return m_name.c_str();
}
const char* name() const override { return m_name.c_str(); }


bool is_reproducible() const override {
return m_impl->is_reproducible();
}
bool is_reproducible() const override { return m_impl->is_reproducible(); }


static void modify_size_args(SizeArgs &args,
TensorLayout &src_pg, TensorLayout &dst_pg);
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg,
TensorLayout& dst_pg);
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
std::string param() const override {
std::string ret;
serialize_write_pod(m_impl, ret);
return ret;
}
}; };


class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase { class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase {
bool m_is_reproducible;
const char *m_name;
cudnnConvolutionFwdAlgo_t m_cudnn_enum; cudnnConvolutionFwdAlgo_t m_cudnn_enum;
CudnnAlgoPack::Attr m_attr;


public:
public:
AlgoCUDNN(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) {
megdnn_assert(CudnnAlgoPack::conv3d_fwd_algos().find(cudnn_enum) !=
CudnnAlgoPack::conv3d_fwd_algos().end());
m_attr = CudnnAlgoPack::conv3d_fwd_algos().at(cudnn_enum);
}


AlgoCUDNN(bool is_reproducible, const char *name,
cudnnConvolutionFwdAlgo_t cudnn_enum):
m_is_reproducible(is_reproducible),
m_name(name),
m_cudnn_enum(cudnn_enum)
{}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
bool is_reproducible() const override { return m_attr.is_reproducible; }


bool is_reproducible() const override {
return m_is_reproducible;
}
const char* name() const override { return m_attr.name.c_str(); }


const char* name() const override {
return m_name;
}
cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }


cudnnConvolutionFwdAlgo_t cudnn_enum() const {
return m_cudnn_enum;
}
bool is_cudnn() const override { return true; }


bool is_cudnn() const override {
return true;
}
};
MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)


class Convolution3DForwardImpl::AlgoInplaceMatmul final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
std::string param() const override {
std::string ret;
serialize_write_pod(m_cudnn_enum, ret);
return ret;
}


const char* name() const override {
return "INPLACE_MATMUL";
}
bool is_reproducible() const override {
return true;
}
}; };


class Convolution3DForwardImpl::AlgoInplaceMatmul final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;


class Convolution3DForwardImpl::AlgoChanwise final: public AlgoBase {
public:
bool is_available(const SizeArgs &args) const override;
size_t get_workspace_in_bytes(const SizeArgs &args) const override;
void exec(const ExecArgs &args) const override;
const char* name() const override { return "INPLACE_MATMUL"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL)
};


const char* name() const override {
return "CHANNEL_WISE";
}
bool is_reproducible() const override {
return true;
}
class Convolution3DForwardImpl::AlgoChanwise final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

const char* name() const override { return "CHANNEL_WISE"; }
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
}; };


class Convolution3DForwardImpl::AlgoPack {
class Convolution3DForwardImpl::AlgoPack : NonCopyableObj {
// defined in cudnn.cpp // defined in cudnn.cpp
void fill_cudnn_algos(); void fill_cudnn_algos();


AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator = (const AlgoPack &) = delete;
AlgoBase::Mapper m_all_algos_map;


public:
AlgoPack();
public:
AlgoPack();


std::vector<AlgoCUDNN> cudnn;
Algo1x1x1 a1x1x1;
AlgoInplaceMatmul inplace_matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
std::vector<AlgoCUDNN> cudnn;
Algo1x1x1 a1x1x1;
AlgoInplaceMatmul inplace_matmul;
AlgoChanwise chanwise;
std::vector<AlgoGroupConvGeneral> gconv;
std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;


std::vector<AlgoBase*>
std::vector<AlgoBase*>
//! all algorithms //! all algorithms
all_algos, all_algos,
//! non-cudnn algos, used for heuristic if cudnn is not supported //! non-cudnn algos, used for heuristic if cudnn is not supported
non_cudnn_algos; non_cudnn_algos;


AlgoCUDNN* cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo);
AlgoCUDNN* cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo);

const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 3
- 23
dnn/src/cuda/convolution3d/forward/cudnn.cpp View File

@@ -78,30 +78,10 @@ void Convolution3DForwardImpl::AlgoCUDNN::exec(
cudnnGetErrorString(status), args.to_string().c_str()); cudnnGetErrorString(status), args.to_string().c_str());
} }



void Convolution3DForwardImpl::AlgoPack::fill_cudnn_algos() { void Convolution3DForwardImpl::AlgoPack::fill_cudnn_algos() {
#define V1(v) #v
#define V(v) V1(v)

#define DEF_ALGO(NAME, REPROD) \
cudnn.push_back({ \
REPROD, #NAME \
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) \
"." V(CUDNN_PATCHLEVEL), \
NAME})

DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true);
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true);
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true);

#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif

#undef DEF_ALGO

#undef V
#undef V1
for (auto&& algo : CudnnAlgoPack::conv3d_fwd_algos()) {
cudnn.push_back(algo.first);
}
} }


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 147
- 117
dnn/src/cuda/convolution3d/opr_impl.h View File

@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once


@@ -15,126 +16,155 @@
namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


class Convolution3DForwardImpl: public Convolution3DForward {
public:
using Convolution3DForward::Convolution3DForward;
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in filter,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &filter,
const TensorLayout &dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const CanonizedFilterMeta& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) override;
const char* get_algorithm_set_name() const override;
class AlgoBase;
class AlgoCUDNN;
class Algo1x1x1;
class AlgoInplaceMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;
class AlgoPack;
static const AlgoPack& algo_pack() {
return sm_algo_pack;
}
private:
static AlgoPack sm_algo_pack;
class Convolution3DForwardImpl : public Convolution3DForward {
public:
using Convolution3DForward::Convolution3DForward;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src,
const CanonizedFilterMeta& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) {
return get_algorithm_heuristic(src, filter, dst,
workspace_limit_in_bytes, reproducible)
->info();
}
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) override;
const char* get_algorithm_set_name() const override;
class AlgoBase;
class AlgoCUDNN;
class Algo1x1x1;
class AlgoInplaceMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;
class AlgoPack;
static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;

private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const CanonizedFilterMeta& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible);


static AlgoPack sm_algo_pack;
}; };


class Convolution3DBackwardDataImpl: public Convolution3DBackwardData {
public:
using Convolution3DBackwardData::Convolution3DBackwardData;
void exec(_megdnn_tensor_in filter,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &filter,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) override;
const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoCUDNN;
class AlgoInplaceMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;

class AlgoPack;

static const AlgoPack& algo_pack() {
return sm_algo_pack;
}

private:
static AlgoPack sm_algo_pack;
class Convolution3DBackwardDataImpl : public Convolution3DBackwardData {
public:
using Convolution3DBackwardData::Convolution3DBackwardData;
void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
AlgorithmInfo get_algorithm_info_heuristic(
const CanonizedFilterMeta& filter, const TensorLayout& diff,
const TensorLayout& grad, size_t workspace_limit_in_bytes,
bool reproducible) {
return get_algorithm_heuristic(filter, diff, grad,
workspace_limit_in_bytes, reproducible)
->info();
}
size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad) override;
const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoCUDNN;
class AlgoInplaceMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;

class AlgoPack;

static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;

private:
Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible);

static AlgoPack sm_algo_pack;
}; };


class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter {
public:
using Convolution3DBackwardFilter::Convolution3DBackwardFilter;
void exec(_megdnn_tensor_in src,
_megdnn_tensor_in diff,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
std::vector<Algorithm *> get_all_algorithms(const TensorLayout &src,
const TensorLayout &diff,
const TensorLayout &grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes,
bool reproducible);
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) override;
const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoCUDNN;
class AlgoInplaceMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;

class AlgoPack;

static const AlgoPack& algo_pack() {
return sm_algo_pack;
}

private:
static AlgoPack sm_algo_pack;
class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter {
public:
using Convolution3DBackwardFilter::Convolution3DBackwardFilter;
void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad) override;
AlgorithmInfo get_algorithm_info_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes,
bool reproducible) {
return get_algorithm_heuristic(src, diff, grad,
workspace_limit_in_bytes, reproducible)
->info();
}

const char* get_algorithm_set_name() const override;

class AlgoBase;
class AlgoCUDNN;
class AlgoInplaceMatmul;
class AlgoChanwise;
class AlgoGroupConvGeneral;

class AlgoPack;

static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;

private:
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const CanonizedFilterMeta& grad,
size_t workspace_limit_in_bytes,
bool reproducible);

static AlgoPack sm_algo_pack;
}; };
} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 131
- 0
dnn/src/cuda/cudnn_wrapper.cpp View File

@@ -433,6 +433,137 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) {
desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT)); desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT));
} }


////////////////////////// CudnnAlgoPack //////////////////////////

#define V1(v) #v
#define V(v) V1(v)
#define DEF_NAME(NAME) \
#NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL)
#define DEF_ALGO(NAME, PROD) \
{ \
NAME, { DEF_NAME(NAME), PROD } \
}

#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1)
#pragma message "not latest cudnn"
#endif

const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr>
CudnnAlgoPack::conv_bwd_data_algos() {
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t,
CudnnAlgoPack::Attr>
algos = {
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true),
#if CUDNN_MAJOR >= 5
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true),
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED,
true),
#endif
#endif
};

return algos;
}

const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr>
CudnnAlgoPack::conv_bwd_flt_algos() {
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t,
CudnnAlgoPack::Attr>
algos = {
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false),
#if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1)
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
true),
#if CUDNN_MAJOR >= 6
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true),
#endif
#endif

};

return algos;
}


const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr>
CudnnAlgoPack::conv_fwd_algos() {
static const std::unordered_map<cudnnConvolutionFwdAlgo_t,
CudnnAlgoPack::Attr>
algos = {
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true),
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
true),
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true),
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true),
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true),
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true),

#if CUDNN_MAJOR >= 5
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true),
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true),
#endif
#endif

};

return algos;
}

const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr>
CudnnAlgoPack::conv3d_bwd_data_algos() {
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t,
CudnnAlgoPack::Attr>
algos = {
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true),
};

return algos;
} // namespace cuda

const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr>
CudnnAlgoPack::conv3d_bwd_flt_algos() {
#pragma message \
"fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc"
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t,
CudnnAlgoPack::Attr>
algos = {
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true),
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false),
};

return algos;
}

const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr>
CudnnAlgoPack::conv3d_fwd_algos() {
static const std::unordered_map<cudnnConvolutionFwdAlgo_t,
CudnnAlgoPack::Attr>
algos = {
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true),
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
true),
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true),
};

return algos;
}

#undef DEF_ALGO
#undef DEF_NAME
#undef V
#undef V1

} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn




+ 47
- 3
dnn/src/cuda/cudnn_wrapper.h View File

@@ -10,6 +10,7 @@
*/ */
#pragma once #pragma once


#include <unordered_map>
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn.h"
#include "src/cuda/cudnn_with_check.h" #include "src/cuda/cudnn_with_check.h"
@@ -27,7 +28,7 @@ class TensorDesc {
public: public:
TensorDesc(); TensorDesc();
//! default layout is nchw //! default layout is nchw
void set(const TensorLayout& layout, const param::Convolution::Format =
void set(const TensorLayout& layout, const param::Convolution::Format =
param::Convolution::Format::NCHW); param::Convolution::Format::NCHW);
~TensorDesc(); ~TensorDesc();
cudnnTensorDescriptor_t desc; cudnnTensorDescriptor_t desc;
@@ -103,9 +104,52 @@ class Conv3DDesc {
cudnnConvolutionDescriptor_t desc; cudnnConvolutionDescriptor_t desc;
}; };


class CudnnAlgoPack {
public:
//! algorithm attr
struct Attr {
std::string name;
bool is_reproducible;
};


static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr>
conv_bwd_data_algos();


} // namespace cuda
} // namespace megdnn
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr>
conv_bwd_flt_algos();

static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr>
conv_fwd_algos();

static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr>
conv3d_bwd_data_algos();

static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, Attr>
conv3d_bwd_flt_algos();

static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr>
conv3d_fwd_algos();

};

} // namespace cuda
} // namespace megdnn

namespace std {

#define DEF_HASH(_type) \
template <> \
struct hash<_type> { \
std::size_t operator()(const _type& algo) const { \
return std::hash<uint32_t>()(static_cast<uint32_t>(algo)); \
} \
}

DEF_HASH(cudnnConvolutionBwdDataAlgo_t);
DEF_HASH(cudnnConvolutionBwdFilterAlgo_t);
DEF_HASH(cudnnConvolutionFwdAlgo_t);

#undef DEF_HASH
} // namespace std


// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 5
- 0
dnn/src/cuda/deformable_conv/bwd_data/algo.cpp View File

@@ -19,7 +19,12 @@ using OprImpl = DeformableConvBackwardDataImpl;


OprImpl::AlgoPack::AlgoPack() { OprImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&algo_matmul); all_algos.push_back(&algo_matmul);

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }
MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardDataImpl)


OprImpl::AlgoPack OprImpl::sm_algo_pack; OprImpl::AlgoPack OprImpl::sm_algo_pack;




+ 13
- 4
dnn/src/cuda/deformable_conv/bwd_data/algo.h View File

@@ -13,11 +13,15 @@


#include "megdnn/oprs.h" #include "megdnn/oprs.h"


#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/cuda/handle.h" #include "src/cuda/handle.h"


#include "src/cuda/deformable_conv/opr_impl.h" #include "src/cuda/deformable_conv/opr_impl.h"


#include <unordered_map>

namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


@@ -26,6 +30,10 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_MATMUL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
DeformableConvBackwardDataImpl* opr; DeformableConvBackwardDataImpl* opr;
@@ -107,17 +115,18 @@ public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }


const char* name() const override { return "AlgoMatmul"; } const char* name() const override { return "AlgoMatmul"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
}; };


class DeformableConvBackwardDataImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;

class DeformableConvBackwardDataImpl::AlgoPack : NonCopyableObj {
AlgoBase::Mapper m_all_algos_map;
public: public:
AlgoPack(); AlgoPack();
AlgoMatmul algo_matmul; AlgoMatmul algo_matmul;
//! all algorithms //! all algorithms
std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;

const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda } // namespace cuda


+ 4
- 0
dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp View File

@@ -20,7 +20,11 @@ using OprImpl = DeformableConvBackwardFilterImpl;


OprImpl::AlgoPack::AlgoPack() { OprImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&algo_matmul); all_algos.push_back(&algo_matmul);
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }
MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardFilterImpl)


OprImpl::AlgoPack OprImpl::sm_algo_pack; OprImpl::AlgoPack OprImpl::sm_algo_pack;




+ 13
- 4
dnn/src/cuda/deformable_conv/bwd_flt/algo.h View File

@@ -13,11 +13,15 @@


#include "megdnn/oprs.h" #include "megdnn/oprs.h"


#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/cuda/handle.h" #include "src/cuda/handle.h"


#include "src/cuda/deformable_conv/opr_impl.h" #include "src/cuda/deformable_conv/opr_impl.h"


#include <unordered_map>

namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


@@ -26,6 +30,11 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_MATMUL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
DeformableConvBackwardFilterImpl* opr; DeformableConvBackwardFilterImpl* opr;
@@ -97,18 +106,18 @@ public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }


const char* name() const override { return "AlgoMatmul"; } const char* name() const override { return "AlgoMatmul"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
}; };


class DeformableConvBackwardFilterImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;

class DeformableConvBackwardFilterImpl::AlgoPack : NonCopyableObj {
AlgoBase::Mapper m_all_algos_map;
public: public:
AlgoPack(); AlgoPack();


AlgoMatmul algo_matmul; AlgoMatmul algo_matmul;
//! all algorithms //! all algorithms
std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda } // namespace cuda


+ 6
- 0
dnn/src/cuda/deformable_conv/fwd/algo.cpp View File

@@ -22,8 +22,14 @@ using OprImpl = DeformableConvForwardImpl;


OprImpl::AlgoPack::AlgoPack() { OprImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&algo_matmul); all_algos.push_back(&algo_matmul);

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvForwardImpl)

OprImpl::AlgoPack OprImpl::sm_algo_pack; OprImpl::AlgoPack OprImpl::sm_algo_pack;


OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im,


+ 13
- 4
dnn/src/cuda/deformable_conv/fwd/algo.h View File

@@ -13,9 +13,13 @@


#include "megdnn/oprs.h" #include "megdnn/oprs.h"


#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/deformable_conv/opr_impl.h" #include "src/cuda/deformable_conv/opr_impl.h"
#include "src/cuda/utils.h" #include "src/cuda/utils.h"


#include <unordered_map>

namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


@@ -24,6 +28,11 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_MATMUL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
DeformableConvForwardImpl* opr; DeformableConvForwardImpl* opr;
@@ -92,17 +101,17 @@ public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }


const char* name() const override { return "AlgoMatmul"; } const char* name() const override { return "AlgoMatmul"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
}; };


class DeformableConvForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;

class DeformableConvForwardImpl::AlgoPack : NonCopyableObj {
AlgoBase::Mapper m_all_algos_map;
public: public:
AlgoPack(); AlgoPack();
AlgoMatmul algo_matmul; AlgoMatmul algo_matmul;
//! all algorithms //! all algorithms
std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda } // namespace cuda


+ 65
- 50
dnn/src/cuda/deformable_conv/opr_impl.h View File

@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once


@@ -29,19 +30,6 @@ public:
const TensorLayout& mask, const TensorLayout& mask,
const TensorLayout& dst) override; const TensorLayout& dst) override;


std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& dst) override;

Algorithm* get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& filter,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;

Algorithm* get_algorithm_heuristic(const TensorLayout& im, Algorithm* get_algorithm_heuristic(const TensorLayout& im,
const CanonizedFilterMeta& filter, const CanonizedFilterMeta& filter,
const TensorLayout& offset, const TensorLayout& offset,
@@ -58,31 +46,35 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& dst) override;

Algorithm* get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& filter,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;


private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };


class DeformableConvBackwardFilterImpl: public DeformableConvBackwardFilter {
class DeformableConvBackwardFilterImpl : public DeformableConvBackwardFilter {
public: public:
using DeformableConvBackwardFilter::DeformableConvBackwardFilter; using DeformableConvBackwardFilter::DeformableConvBackwardFilter;


void exec(_megdnn_tensor_in im,_megdnn_tensor_in offset, _megdnn_tensor_in mask,
_megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad,
void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
_megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
_megdnn_tensor_out filter_grad,
_megdnn_workspace workspace) override; _megdnn_workspace workspace) override;


std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& filter_grad) override;

Algorithm* get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& out_grad,
const TensorLayout& filter_grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;

Algorithm* get_algorithm_heuristic(const TensorLayout& im, Algorithm* get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& offset, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& mask,
@@ -91,9 +83,11 @@ public:
size_t workspace_limit_in_bytes, size_t workspace_limit_in_bytes,
bool reproducible); bool reproducible);


size_t get_workspace_in_bytes(
const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& filter_grad) override;
size_t get_workspace_in_bytes(const TensorLayout& im,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& out_grad,
const TensorLayout& filter_grad) override;


const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


@@ -103,6 +97,21 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& im, const TensorLayout& offset,
const TensorLayout& mask, const TensorLayout& out_grad,
const TensorLayout& filter_grad) override;

Algorithm* get_algorithm_heuristic(const TensorLayout& im,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& out_grad,
const TensorLayout& filter_grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;


private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
@@ -118,19 +127,6 @@ public:
_megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad,
_megdnn_workspace workspace) override; _megdnn_workspace workspace) override;


std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad) override;

Algorithm* get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, bool reproducible) override;

Algorithm* get_algorithm_heuristic( Algorithm* get_algorithm_heuristic(
const TensorLayout& im, const CanonizedFilterMeta& filter, const TensorLayout& im, const CanonizedFilterMeta& filter,
const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& offset, const TensorLayout& mask,
@@ -138,11 +134,14 @@ public:
const TensorLayout& offset_grad, const TensorLayout& mask_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, bool reproducible); size_t workspace_limit_in_bytes, bool reproducible);


size_t get_workspace_in_bytes(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad) override;
size_t get_workspace_in_bytes(const TensorLayout& im,
const TensorLayout& filter,
const TensorLayout& offset,
const TensorLayout& mask,
const TensorLayout& out_grad,
const TensorLayout& im_grad,
const TensorLayout& offset_grad,
const TensorLayout& mask_grad) override;


const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


@@ -152,6 +151,22 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad,
const TensorLayout& mask_grad) override;

Algorithm* get_algorithm_heuristic(
const TensorLayout& im, const TensorLayout& filter,
const TensorLayout& offset, const TensorLayout& mask,
const TensorLayout& out_grad, const TensorLayout& im_grad,
const TensorLayout& offset_grad, const TensorLayout& mask_grad,
size_t workspace_limit_in_bytes, bool reproducible) override;


private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;


+ 6
- 0
dnn/src/cuda/local_share/backward_data/algo.cpp View File

@@ -18,8 +18,14 @@ using namespace cuda;
LocalShareBackwardDataImpl::AlgoPack::AlgoPack() { LocalShareBackwardDataImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&implicit_gemm); all_algos.push_back(&implicit_gemm);
all_algos.push_back(&batched_matmul); all_algos.push_back(&batched_matmul);

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardDataImpl)

LocalShareBackwardDataImpl::AlgoPack LocalShareBackwardDataImpl::sm_algo_pack; LocalShareBackwardDataImpl::AlgoPack LocalShareBackwardDataImpl::sm_algo_pack;


LocalShareBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( LocalShareBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(


+ 16
- 3
dnn/src/cuda/local_share/backward_data/algo.h View File

@@ -13,10 +13,14 @@


#include "megdnn/oprs.h" #include "megdnn/oprs.h"


#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#include "src/cuda/local_share/opr_impl.h" #include "src/cuda/local_share/opr_impl.h"


#include <unordered_map>

namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


@@ -25,6 +29,13 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_IMPLICIT_GEMM,
CUDA_BATCHED_MATMUL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;


AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
LocalShareBackwardDataImpl* opr; LocalShareBackwardDataImpl* opr;
@@ -77,6 +88,7 @@ public:
const char* name() const override { const char* name() const override {
return "LOCAL_SHARE_IMPLICIT_GEMM"; return "LOCAL_SHARE_IMPLICIT_GEMM";
} }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM)
}; };


class LocalShareBackwardDataImpl::AlgoBatchedMatMul final class LocalShareBackwardDataImpl::AlgoBatchedMatMul final
@@ -93,11 +105,11 @@ public:
const char* name() const override { const char* name() const override {
return "LOCAL_SHARE_BATCHED_MATMUL"; return "LOCAL_SHARE_BATCHED_MATMUL";
} }
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)
}; };


class LocalShareBackwardDataImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
class LocalShareBackwardDataImpl::AlgoPack : NonCopyableObj {
AlgoBase::Mapper m_all_algos_map;


public: public:
AlgoPack(); AlgoPack();
@@ -106,6 +118,7 @@ public:
AlgoBatchedMatMul batched_matmul; AlgoBatchedMatMul batched_matmul;


std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda } // namespace cuda


+ 6
- 0
dnn/src/cuda/local_share/backward_filter/algo.cpp View File

@@ -18,8 +18,14 @@ using namespace cuda;
LocalShareBackwardFilterImpl::AlgoPack::AlgoPack() { LocalShareBackwardFilterImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&implicit_gemm); all_algos.push_back(&implicit_gemm);
all_algos.push_back(&batched_matmul); all_algos.push_back(&batched_matmul);

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareBackwardFilterImpl)

LocalShareBackwardFilterImpl::AlgoPack LocalShareBackwardFilterImpl::sm_algo_pack; LocalShareBackwardFilterImpl::AlgoPack LocalShareBackwardFilterImpl::sm_algo_pack;


LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs(


+ 16
- 3
dnn/src/cuda/local_share/backward_filter/algo.h View File

@@ -13,10 +13,14 @@


#include "megdnn/oprs.h" #include "megdnn/oprs.h"


#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#include "src/cuda/local_share/opr_impl.h" #include "src/cuda/local_share/opr_impl.h"


#include <unordered_map>

namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


@@ -25,6 +29,12 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_IMPLICIT_GEMM,
CUDA_BATCHED_MATMUL,
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
LocalShareBackwardFilterImpl* opr; LocalShareBackwardFilterImpl* opr;
@@ -75,6 +85,7 @@ public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }


const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM)
}; };


class LocalShareBackwardFilterImpl::AlgoBatchedMatMul final : public AlgoBase { class LocalShareBackwardFilterImpl::AlgoBatchedMatMul final : public AlgoBase {
@@ -88,11 +99,11 @@ public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }


const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)
}; };


class LocalShareBackwardFilterImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
class LocalShareBackwardFilterImpl::AlgoPack : NonCopyableObj {
AlgoBase::Mapper m_all_algos_map;


public: public:
AlgoPack(); AlgoPack();
@@ -101,6 +112,8 @@ public:
AlgoBatchedMatMul batched_matmul; AlgoBatchedMatMul batched_matmul;


std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;

const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda } // namespace cuda


+ 6
- 0
dnn/src/cuda/local_share/forward/algo.cpp View File

@@ -19,8 +19,14 @@ LocalShareForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&batch_size_aware_chwn_small_image); all_algos.push_back(&batch_size_aware_chwn_small_image);
all_algos.push_back(&batch_size_aware_chwn); all_algos.push_back(&batch_size_aware_chwn);
all_algos.push_back(&batched_matmul); all_algos.push_back(&batched_matmul);

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareForwardImpl)

LocalShareForwardImpl::AlgoPack LocalShareForwardImpl::sm_algo_pack; LocalShareForwardImpl::AlgoPack LocalShareForwardImpl::sm_algo_pack;


LocalShareForwardImpl::AlgoBase::SizeArgs::SizeArgs(LocalShareForwardImpl* o, LocalShareForwardImpl::AlgoBase::SizeArgs::SizeArgs(LocalShareForwardImpl* o,


+ 17
- 3
dnn/src/cuda/local_share/forward/algo.h View File

@@ -14,9 +14,13 @@
#include "megdnn/oprs.h" #include "megdnn/oprs.h"


#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#include "src/cuda/local_share/opr_impl.h" #include "src/cuda/local_share/opr_impl.h"


#include <unordered_map>

namespace megdnn { namespace megdnn {
namespace cuda { namespace cuda {


@@ -25,6 +29,13 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_CHWN_BATCH_SIZE_AWARE,
CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE,
CUDA_BATCHED_MATMUL
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
LocalShareForwardImpl* opr; LocalShareForwardImpl* opr;
@@ -79,6 +90,7 @@ public:
const char* name() const override { const char* name() const override {
return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE";
} }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE)
}; };


class LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage final class LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage final
@@ -95,6 +107,7 @@ public:
const char* name() const override { const char* name() const override {
return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE"; return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE";
} }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE)
}; };


class LocalShareForwardImpl::AlgoBatchedMatMul final : public AlgoBase { class LocalShareForwardImpl::AlgoBatchedMatMul final : public AlgoBase {
@@ -108,11 +121,11 @@ public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }


const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL)
}; };


class LocalShareForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
class LocalShareForwardImpl::AlgoPack : NonCopyableObj {
AlgoBase::Mapper m_all_algos_map;


public: public:
AlgoPack(); AlgoPack();
@@ -122,6 +135,7 @@ public:
AlgoBatchedMatMul batched_matmul; AlgoBatchedMatMul batched_matmul;


std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda } // namespace cuda


+ 32
- 24
dnn/src/cuda/local_share/opr_impl.h View File

@@ -23,14 +23,6 @@ public:
size_t get_workspace_in_bytes(const TensorLayout& src, size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& filter, const TensorLayout& filter,
const TensorLayout& dst) override; const TensorLayout& dst) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


class AlgoBase; class AlgoBase;
@@ -41,7 +33,17 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);


protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) override;
private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
}; };
@@ -54,14 +56,6 @@ public:
size_t get_workspace_in_bytes(const TensorLayout& filter, size_t get_workspace_in_bytes(const TensorLayout& filter,
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


class AlgoBase; class AlgoBase;
@@ -71,6 +65,17 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& filter, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& filter,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;


private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;
@@ -84,14 +89,6 @@ public:
size_t get_workspace_in_bytes(const TensorLayout& src, size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& diff, const TensorLayout& diff,
const TensorLayout& grad) override; const TensorLayout& grad) override;
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;
const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


class AlgoBase; class AlgoBase;
@@ -101,6 +98,17 @@ public:
class AlgoPack; class AlgoPack;


static const AlgoPack& algo_pack() { return sm_algo_pack; } static const AlgoPack& algo_pack() { return sm_algo_pack; }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& src, const TensorLayout& diff,
const TensorLayout& grad) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& diff,
const TensorLayout& grad,
size_t workspace_limit_in_bytes,
bool reproducible) override;


private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;


+ 8
- 0
dnn/src/cuda/matrix_mul/algos.cpp View File

@@ -11,6 +11,7 @@


#include "./algos.h" #include "./algos.h"
#include "src/cuda/utils.h" #include "src/cuda/utils.h"
#include "src/common/algo_base.h"


#include <cuda.h> #include <cuda.h>
#if CUDA_VERSION >= 10010 #if CUDA_VERSION >= 10010
@@ -33,10 +34,16 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas); cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas);
all_algos.push_back(cublas_bfloat16.get()); all_algos.push_back(cublas_bfloat16.get());
#endif #endif

for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }


MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;


MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl)

MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o,
const TensorLayout& A, const TensorLayout& A,
const TensorLayout& B, const TensorLayout& B,
@@ -67,4 +74,5 @@ std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const {
m, k, k, n, m, n, param.transposeA, param.transposeB, m, k, k, n, m, n, param.transposeA, param.transposeB,
layout_a.stride[0], layout_b.stride[0], layout_c.stride[0])); layout_a.stride[0], layout_b.stride[0], layout_c.stride[0]));
} }

// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 38
- 26
dnn/src/cuda/matrix_mul/algos.h View File

@@ -6,14 +6,18 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#pragma once #pragma once
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/cuda/matrix_mul/opr_impl.h" #include "src/cuda/matrix_mul/opr_impl.h"
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"


#include <unordered_map>
#include <cuda.h> #include <cuda.h>
#include <memory> #include <memory>
#if CUDA_VERSION >= 10010 #if CUDA_VERSION >= 10010
@@ -32,6 +36,15 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
enum class AlgoType : uint32_t {
CUDA_CUBLAS,
CUDA_WMMA_UINT4X4X32,
CUDA_CUBLASLT,
CUDA_NAIVE,
CUDA_BFLOAT16
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
MatrixMulForwardImpl* opr; MatrixMulForwardImpl* opr;
@@ -62,12 +75,12 @@ public:
virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
virtual void exec(const ExecArgs& args) const = 0; virtual void exec(const ExecArgs& args) const = 0;


bool is_available_wk(const SizeArgs& args, size_t limit) {
bool is_available_wk(const SizeArgs& args, size_t limit) const {
return is_available(args) && get_workspace_in_bytes(args) <= limit; return is_available(args) && get_workspace_in_bytes(args) <= limit;
} }
bool is_available_reproducible( bool is_available_reproducible(
const SizeArgs& args, bool reproducible = true, const SizeArgs& args, bool reproducible = true,
size_t limit = std::numeric_limits<size_t>::max()) {
size_t limit = std::numeric_limits<size_t>::max()) const {
return (!reproducible || is_reproducible()) && return (!reproducible || is_reproducible()) &&
is_available_wk(args, limit); is_available_wk(args, limit);
} }
@@ -80,8 +93,6 @@ public:
name(), req, workspace.size); name(), req, workspace.size);
return *this; return *this;
} }


}; };


class MatrixMulForwardImpl::AlgoCuBlas final : public AlgoBase { class MatrixMulForwardImpl::AlgoCuBlas final : public AlgoBase {
@@ -91,13 +102,10 @@ public:
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override { size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override {
return 0_z; return 0_z;
} }
const char* name() const override {
return "CUBLAS";
}
const char* name() const override { return "CUBLAS"; }
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
bool is_reproducible() const override {
return true;
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
}; };


#if CUDA_VERSION >= 10000 #if CUDA_VERSION >= 10000
@@ -106,13 +114,10 @@ public:
AlgoUInt4x4x32WMMA() = default; AlgoUInt4x4x32WMMA() = default;
bool is_available(const SizeArgs& args) const override; bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override {
return "UINT4x4x32_WMMA";
}
const char* name() const override { return "UINT4x4x32_WMMA"; }
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
bool is_reproducible() const override {
return true;
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32)
}; };
#endif #endif
#if CUDA_VERSION >= 10010 #if CUDA_VERSION >= 10010
@@ -120,13 +125,10 @@ class MatrixMulForwardImpl::AlgoCuBlasLt final : public AlgoBase {
public: public:
bool is_available(const SizeArgs& args) const override; bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override {
return "CUBLAS_LT";
}
const char* name() const override { return "CUBLAS_LT"; }
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
bool is_reproducible() const override {
return true;
}
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT)
}; };
#endif #endif


@@ -140,6 +142,7 @@ public:
const char* name() const override { return "NAIVE"; } const char* name() const override { return "NAIVE"; }
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE)
}; };


#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
@@ -151,6 +154,13 @@ public:
const char* name() const override { return m_name.c_str(); } const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE)

std::string param() const override {
std::string ret;
serialize_write_pod(m_algorithm, ret);
return ret;
}


private: private:
MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr;
@@ -160,9 +170,9 @@ private:
}; };
#endif #endif


class MatrixMulForwardImpl::AlgoPack {
AlgoPack(const AlgoPack&) = delete;
AlgoPack& operator=(const AlgoPack&) = delete;
class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;


public: public:
AlgoPack(); AlgoPack();
@@ -178,6 +188,8 @@ public:
std::unique_ptr<AlgoBFloat16> cublas_bfloat16; std::unique_ptr<AlgoBFloat16> cublas_bfloat16;
#endif #endif
std::vector<AlgoBase*> all_algos; std::vector<AlgoBase*> all_algos;

const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


} // namespace cuda } // namespace cuda


+ 1
- 1
dnn/src/cuda/matrix_mul/bfloat16.cpp View File

@@ -82,7 +82,7 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
args.opr->handle()->create_operator<MatrixMulForward>(); args.opr->handle()->create_operator<MatrixMulForward>();
matmul_opr->param() = args.opr->param(); matmul_opr->param() = args.opr->param();
matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
matmul_opr->execution_policy() = {m_algorithm};
matmul_opr->execution_policy() = {m_algorithm->info()};
matmul_opr->exec(a, b, c, ctypecvt.workspace()); matmul_opr->exec(a, b, c, ctypecvt.workspace());
} }
ctypecvt.comp_to_dst_type(c, args.tensor_c); ctypecvt.comp_to_dst_type(c, args.tensor_c);


+ 11
- 9
dnn/src/cuda/matrix_mul/opr_impl.h View File

@@ -25,15 +25,6 @@ public:


bool is_thread_safe() const override { return true; } bool is_thread_safe() const override { return true; }


std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C,
size_t workspace_limit_in_bytes,
bool reproducible) override;

const char* get_algorithm_set_name() const override { const char* get_algorithm_set_name() const override {
return "CUDA MATMUL"; return "CUDA MATMUL";
} }
@@ -55,6 +46,17 @@ public:
static const AlgoPack& algo_pack() { static const AlgoPack& algo_pack() {
return sm_algo_pack; return sm_algo_pack;
} }
static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc);

protected:
std::vector<Algorithm*> get_all_algorithms(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C) override;
Algorithm* get_algorithm_heuristic(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout& C,
size_t workspace_limit_in_bytes,
bool reproducible) override;


private: private:
static AlgoPack sm_algo_pack; static AlgoPack sm_algo_pack;


+ 5
- 0
dnn/src/fallback/conv_bias/algos.cpp View File

@@ -10,10 +10,14 @@
*/ */


#include "src/fallback/conv_bias/algos.h" #include "src/fallback/conv_bias/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos.h"
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
#include "src/fallback/conv_bias/im2col/algos.h"
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/winograd/strategy.h" #include "src/fallback/conv_bias/winograd/strategy.h"
#include "src/naive/convolution/helper.h" #include "src/naive/convolution/helper.h"
#include "src/common/algo_base.h"


#include "midout.h" #include "midout.h"


@@ -176,6 +180,7 @@ void kern_default(const ConvBiasImpl::NCBKernParam& p) {
} // namespace } // namespace


MIDOUT_DECL(megdnn_fallback_naive) MIDOUT_DECL(megdnn_fallback_naive)

/* ======================= AlgoNaive ======================== */ /* ======================= AlgoNaive ======================== */


bool ConvBiasImpl::AlgoNaive::usable( bool ConvBiasImpl::AlgoNaive::usable(


+ 25
- 0
dnn/src/fallback/conv_bias/algos.h View File

@@ -36,6 +36,7 @@ public:
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
return {support_data_type, AlgoCategory::NAIVE}; return {support_data_type, AlgoCategory::NAIVE};
} }
MEGDNN_DECL_ALGO_TYPE(FB_NAIVE)
}; };


class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase {
@@ -59,6 +60,12 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
} }
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_F32)
std::string param() const override {
std::string ret;
serialize_write_pod(m_matmul_algo, ret);
return ret;
}


private: private:
MatrixMulImpl::AlgoBase* m_matmul_algo; MatrixMulImpl::AlgoBase* m_matmul_algo;
@@ -87,6 +94,12 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
} }
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_4X4_F32)
std::string param() const override {
std::string ret;
serialize_write_pod(m_matmul_algo, ret);
return ret;
}


private: private:
MatrixMulImpl::AlgoBase* m_matmul_algo; MatrixMulImpl::AlgoBase* m_matmul_algo;
@@ -115,6 +128,12 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
} }
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_QS8)
std::string param() const override {
std::string ret;
serialize_write_pod(m_matmul_algo, ret);
return ret;
}


private: private:
MatrixMulImpl::AlgoBase* m_matmul_algo; MatrixMulImpl::AlgoBase* m_matmul_algo;
@@ -143,6 +162,12 @@ public:
ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
} }
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8)
std::string param() const override {
std::string ret;
serialize_write_pod(m_matmul_algo, ret);
return ret;
}


private: private:
MatrixMulImpl::AlgoBase* m_matmul_algo; MatrixMulImpl::AlgoBase* m_matmul_algo;


+ 6
- 0
dnn/src/fallback/conv_bias/common.h View File

@@ -155,6 +155,12 @@ using BiasMode = ConvBiasForward::BiasMode;
const NCBKernSizeParam& param) const override; \ const NCBKernSizeParam& param) const override; \
ConvAlgoTypePack get_algo_type() const override { \ ConvAlgoTypePack get_algo_type() const override { \
return {_algo_data_type, AlgoCategory::WINOGRAD}; \ return {_algo_data_type, AlgoCategory::WINOGRAD}; \
} \
std::string param() const override { \
std::string ret; \
serialize_write_pod(m_matmul_algo, ret); \
serialize_write_pod(m_tile_size, ret); \
return ret; \
} \ } \
\ \
private: \ private: \


+ 7
- 0
dnn/src/fallback/conv_bias/conv1x1/algos.h View File

@@ -60,6 +60,13 @@ public:
return {m_matmul_algo->matmul_description().algo_type.data_type, return {m_matmul_algo->matmul_description().algo_type.data_type,
AlgoCategory::IM2COL}; AlgoCategory::IM2COL};
} }
MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8)
std::string param() const override {
std::string ret;
serialize_write_pod(m_matmul_algo, ret);
serialize_write_pod(m_oc_block_size, ret);
return ret;
}


protected: protected:
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;


+ 1
- 0
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h View File

@@ -43,6 +43,7 @@ public:
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
return {support_data_type, AlgoCategory::IM2COL}; return {support_data_type, AlgoCategory::IM2COL};
} }
MEGDNN_DECL_ALGO_TYPE(FB_CONV1x1_GEMV)


protected: protected:
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;


+ 8
- 0
dnn/src/fallback/conv_bias/im2col/algos.h View File

@@ -68,6 +68,14 @@ public:
return {m_matmul_algo->matmul_description().algo_type.data_type, return {m_matmul_algo->matmul_description().algo_type.data_type,
AlgoCategory::IM2COL}; AlgoCategory::IM2COL};
} }
MEGDNN_DECL_ALGO_TYPE(FB_IM2COL)

std::string param() const override {
std::string ret;
serialize_write_pod(m_matmul_algo, ret);
serialize_write_pod(m_ohw_tile_size, ret);
return ret;
}


private: private:
MatrixMulImpl::AlgoBase* m_matmul_algo; MatrixMulImpl::AlgoBase* m_matmul_algo;


+ 84
- 25
dnn/src/fallback/conv_bias/opr_impl.cpp View File

@@ -22,6 +22,14 @@
#include "src/naive/convolution/algorithms.h" #include "src/naive/convolution/algorithms.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"


#if MEGDNN_X86
#include "src/x86/conv_bias/opr_impl.h"
#elif MEGDNN_AARCH64
#include "src/aarch64/conv_bias/opr_impl.h"
#elif MEGDNN_ARMV7
#include "src/armv7/conv_bias/opr_impl.h"
#endif

#include <cstring> #include <cstring>


using namespace megdnn; using namespace megdnn;
@@ -65,17 +73,19 @@ void incr_ptr(T*& dst, ptrdiff_t delta) {
class ConvBiasImpl::AlgoPack : NonCopyableObj { class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoNaive algo_naive; AlgoNaive algo_naive;
SmallVector<std::unique_ptr<AlgoBase>> refhold; SmallVector<std::unique_ptr<AlgoBase>> refhold;
SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map;


public: public:


AlgoPack() { AlgoPack() {
refhold.emplace_back(new AlgoConv1x1Gemv()); refhold.emplace_back(new AlgoConv1x1Gemv());
all_algos.emplace_back(refhold.back().get());
m_all_algos.emplace_back(refhold.back().get());


static CpuOprDelegationStorage<> storage; static CpuOprDelegationStorage<> storage;
auto matmul_opr = storage.get<MatrixMul>(); auto matmul_opr = storage.get<MatrixMul>();
auto&& matmul_algos =
static_cast<fallback::MatrixMulImpl*>(matmul_opr)->algo_pack();
auto&& matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
->get_all_packed_algo();
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
#if MEGDNN_X86 #if MEGDNN_X86
//! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may //! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may
@@ -97,13 +107,13 @@ public:
refhold.emplace_back(new AlgoIm2col( refhold.emplace_back(new AlgoIm2col(
static_cast<MatrixMulImpl::AlgoBase*>(algo), static_cast<MatrixMulImpl::AlgoBase*>(algo),
ohw_tile_size)); ohw_tile_size));
all_algos.emplace_back(refhold.back().get());
m_all_algos.emplace_back(refhold.back().get());
} }
for (size_t oc_tile_size : {48, 24}) { for (size_t oc_tile_size : {48, 24}) {
refhold.emplace_back(new AlgoConv1x1( refhold.emplace_back(new AlgoConv1x1(
static_cast<MatrixMulImpl::AlgoBase*>(algo), static_cast<MatrixMulImpl::AlgoBase*>(algo),
oc_tile_size)); oc_tile_size));
all_algos.emplace_back(refhold.back().get());
m_all_algos.emplace_back(refhold.back().get());
} }
#endif #endif


@@ -113,26 +123,35 @@ public:
//! FIXME: I do not know a better way to do it. //! FIXME: I do not know a better way to do it.
refhold.emplace_back(new AlgoWinogradF32( refhold.emplace_back(new AlgoWinogradF32(
static_cast<MatrixMulImpl::AlgoBase*>(algo))); static_cast<MatrixMulImpl::AlgoBase*>(algo)));
all_algos.emplace_back(refhold.back().get());
m_all_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoWinogradF32_4x4( refhold.emplace_back(new AlgoWinogradF32_4x4(
static_cast<MatrixMulImpl::AlgoBase*>(algo))); static_cast<MatrixMulImpl::AlgoBase*>(algo)));
all_algos.emplace_back(refhold.back().get());
m_all_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoWinogradQS8( refhold.emplace_back(new AlgoWinogradQS8(
static_cast<MatrixMulImpl::AlgoBase*>(algo))); static_cast<MatrixMulImpl::AlgoBase*>(algo)));
all_algos.emplace_back(refhold.back().get());
m_all_algos.emplace_back(refhold.back().get());
refhold.emplace_back(new AlgoWinogradQS8_8x8( refhold.emplace_back(new AlgoWinogradQS8_8x8(
static_cast<MatrixMulImpl::AlgoBase*>(algo))); static_cast<MatrixMulImpl::AlgoBase*>(algo)));
all_algos.emplace_back(refhold.back().get());
m_all_algos.emplace_back(refhold.back().get());
#endif #endif
} }
all_algos.emplace_back(&algo_naive);
m_all_algos.emplace_back(&algo_naive);

for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
} }
SmallVector<AlgoBase*> all_algos;
const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
}; };


SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
return sl_algo_pack.all_algos;
const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() {
static AlgoPack algo_pack;
return algo_pack;
}

SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() {
return algo_pack().all_algos();
} }


SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type( SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type(
@@ -140,7 +159,7 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type(
megdnn_assert(nr_type_contain(target_type.data_type), megdnn_assert(nr_type_contain(target_type.data_type),
"ConvBias algo selection only support one type"); "ConvBias algo selection only support one type");
SmallVector<ConvBiasImpl::AlgoBase*> algos; SmallVector<ConvBiasImpl::AlgoBase*> algos;
for (auto&& algo : algo_pack()) {
for (auto&& algo : get_all_packed_algo()) {
auto algo_type = algo->get_algo_type(); auto algo_type = algo->get_algo_type();
if (contain_data_type(algo_type.data_type, target_type.data_type) && if (contain_data_type(algo_type.data_type, target_type.data_type) &&
algo_type.algo_category == target_type.algo_category) { algo_type.algo_category == target_type.algo_category) {
@@ -166,7 +185,7 @@ void ConvBiasImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
workspace.size, preprocessed_filter); workspace.size, preprocessed_filter);
auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace,
preprocessed_filter); preprocessed_filter);
ConvBiasImpl::Algorithm* algo = get_algorithm(fparam, workspace.size);
auto&& algo = get_algorithm(fparam, workspace.size);
if (!is_naive_algo(algo) && if (!is_naive_algo(algo) &&
NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) { NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
exec_with_ncb_kern(fparam, algo); exec_with_ncb_kern(fparam, algo);
@@ -189,9 +208,10 @@ void ConvBiasImpl::exec_preprocess(const TensorLayout& src_layout,
auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace, auto fparam = make_ncb_kern_param(src, filter, bias, dst, workspace,
preprocessed_filter); preprocessed_filter);
//! should not pass workspace_size limit otherwise can not find match algo //! should not pass workspace_size limit otherwise can not find match algo
ConvBiasImpl::Algorithm* algo = get_algorithm(fparam);
if (!is_naive_algo(algo) && NCB_ALGO_FUNC(get_preprocess_workspace, algo,
fparam) <= workspace.size) {
auto&& algo = get_algorithm(fparam);
if (!is_naive_algo(algo) &&
NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <=
workspace.size) {
exec_preprocess_with_ncb_kern(fparam, algo); exec_preprocess_with_ncb_kern(fparam, algo);
} else { } else {
naive::ConvBiasForwardImpl::exec_preprocess( naive::ConvBiasForwardImpl::exec_preprocess(
@@ -207,7 +227,7 @@ size_t ConvBiasImpl::get_workspace_in_bytes(
const PreprocessedFilter* preprocessed_filter) { const PreprocessedFilter* preprocessed_filter) {
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, auto fparam = make_ncb_kern_size_param(src, filter, bias, dst,
preprocessed_filter); preprocessed_filter);
ConvBiasImpl::Algorithm* algo = get_algorithm(fparam);
auto&& algo = get_algorithm(fparam);
if (is_naive_algo(algo)) { if (is_naive_algo(algo)) {
return naive::ConvBiasForwardImpl::get_workspace_in_bytes( return naive::ConvBiasForwardImpl::get_workspace_in_bytes(
src, filter, bias, z, dst, preprocessed_filter); src, filter, bias, z, dst, preprocessed_filter);
@@ -221,7 +241,7 @@ size_t ConvBiasImpl::get_preprocess_workspace_in_bytes(
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) { const TensorLayout& dst) {
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
Algorithm* algo = get_algorithm(fparam);
auto&& algo = get_algorithm(fparam);
if (is_naive_algo(algo)) { if (is_naive_algo(algo)) {
return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes( return naive::ConvBiasForwardImpl::get_preprocess_workspace_in_bytes(
src, filter, bias, z, dst); src, filter, bias, z, dst);
@@ -235,7 +255,7 @@ SmallVector<TensorLayout> ConvBiasImpl::deduce_preprocessed_filter_layout(
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst) { const TensorLayout& dst) {
auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr); auto fparam = make_ncb_kern_size_param(src, filter, bias, dst, nullptr);
Algorithm* algo = get_algorithm(fparam);
auto&& algo = get_algorithm(fparam);
if (is_naive_algo(algo)) { if (is_naive_algo(algo)) {
return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout( return naive::ConvBiasForwardImpl::deduce_preprocessed_filter_layout(
src, filter, bias, z, dst); src, filter, bias, z, dst);
@@ -443,7 +463,7 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
MEGDNN_MARK_USED_VAR(param); MEGDNN_MARK_USED_VAR(param);
std::vector<Algorithm*> algos; std::vector<Algorithm*> algos;
std::vector<Algorithm*> prefer_algos; std::vector<Algorithm*> prefer_algos;
for (auto&& algo : algo_pack()) {
for (auto&& algo : get_all_packed_algo()) {
if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) { if (algo->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
if (algo->is_preferred(param)) { if (algo->is_preferred(param)) {
prefer_algos.push_back(algo); prefer_algos.push_back(algo);
@@ -457,10 +477,49 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
return algos; return algos;
} }


ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc(
const AlgorithmDesc& desc) const {
if (!desc.valid()) {
return nullptr;
} else {
switch (desc.handle_type) {
case Handle::HandleType::FALLBACK: {
const auto& map = algo_pack().all_algos_map();
megdnn_assert(map.find(desc) != map.end());
return map.at(desc);
};

#if MEGDNN_X86
case Handle::HandleType::X86:
return x86::ConvBiasImpl::get_algo_from_desc(desc);
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
case Handle::HandleType::ARM_COMMON:
return arm_common::ConvBiasImpl::get_algo_from_desc(desc);
#if MEGDNN_AARCH64
case Handle::HandleType::AARCH64:
return aarch64::ConvBiasImpl::get_algo_from_desc(desc);
#else
case Handle::HandleType::ARMV7:
return armv7::ConvBiasImpl::get_algo_from_desc(desc);
#endif
#endif
case Handle::HandleType::NAIVE: {
auto algo = static_cast<naive::HandleImpl*>(handle())
->default_conv_bias_fwd_algo();
megdnn_assert(algo->info().desc == desc);
return algo;
}
default:
megdnn_throw("Unknown handle type");
return nullptr;
}
}
}

ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
const NCBKernSizeParam& param, size_t workspace_size) { const NCBKernSizeParam& param, size_t workspace_size) {
if (auto set = execution_policy().algorithm) {
return set;
if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) {
return algo;
} }
if (!m_prev_selected_algo || if (!m_prev_selected_algo ||
memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) { memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {


+ 87
- 1
dnn/src/fallback/conv_bias/opr_impl.h View File

@@ -216,6 +216,86 @@ public:
AlgoBase() : Algorithm() { AlgoBase() : Algorithm() {
m_handle_type = Handle::HandleType::FALLBACK; m_handle_type = Handle::HandleType::FALLBACK;
} }

enum class AlgoType : uint32_t {
//! fallback
FB_NAIVE = 1 << 0,
FB_WINOGRAD_F32,
FB_WINOGRAD_4X4_F32,
FB_WINOGRAD_QS8,
FB_WINOGRAD_8X8_QS8,
FB_CONV1x1,
FB_CONV1x1_GEMV,
FB_IM2COL,

#if MEGDNN_X86
X86_DIRECT = 1 << 8,
X86_DIRECT_STRD2,
X86_WINOGRAD_F63_8x8_F32,
X86_WINOGRAD_F23_8x8_F32,
X86_MKLDNN,
X86_CHANWISE_AVX2_STRD1_QINT8,
X86_CHANWISE_AVX2_STRD2_QINT8,
X86_DIRECT_AVX2_STRD1_INT8,
X86_DIRECT_AVX2_STRD2_INT8,
X86_MKLDNN_QINT8,
X86_MKLDNN_MATMUL_QINT8,
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
ARM_COMMON_WINOGRAD_F23_FP16 = 1 << 8,
ARM_COMMON_WINOGRAD_F45_FP16,
ARM_COMMON_WINOGRAD_F63_FP16,
ARM_COMMON_WINOGRAD_F23_8X8_FP16,
ARM_COMMON_DIRECT_FP16,
ARM_COMMON_DIRECT_STRD1_FP16,
ARM_COMMON_WINOGRAD_F23_4X4_FP32,
ARM_COMMON_WINOGRAD_F63_FP32,
ARM_COMMON_WINOGRAD_F63_4X4_FP32,
ARM_COMMON_WINOGRAD_F54_FP32,
ARM_COMMON_WINOGRAD_F45_FP32,
ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32,
ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32,
ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32,
ARM_COMMON_DIRECT_FP32,
ARM_COMMON_DIRECT_STRD1_FP32,
ARM_COMMON_DIRECT_STRD2_FP32,
ARM_COMMON_DIRECT_NCHW44_FP32,
ARM_COMMON_DIRECT_NCHW_NCHW44_FP32,
ARM_COMMON_CHWNWISE_NCHW44_F32,
ARM_COMMON_DIRECT_STRD1_S8,
ARM_COMMON_DIRECT_STRD2_S8,
ARM_COMMON_DIRECT_NCHW44,
ARM_COMMON_DIRECT_NCHW_NCHW44_S8,
ARM_COMMON_CHANWISE_STRD1_NCHW44_S8,
ARM_COMMON_CHANWISE_STRD2_NCHW44_S8,
ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8,
ARM_COMMON_DIRECT_STRD1_DOT_S8,
ARM_COMMON_DIRECT_STRD2_DOT_S8,
ARM_COMMON_DIRECT_NCHW44_DOT_S8,
ARM_COMMON_WINOGRAD_F23_8X8_S8,
ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32,
ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8,
ARM_COMMON_DIRECT_INT8X8X16,
ARM_COMMON_DIRECT_NCHW44_INT8X8X16,
ARM_COMMON_DIRECT_STRD2_INT8X8X16,
ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16,
ARM_COMMON_CHWNWISE_STRD1_STRD2_NCHW44_INT8X8X16,
ARM_COMMON_DIRECT_NCHW_NCHW44_INT8X8X16,
ARM_COMMON_DIRECT_STRD1_QU8,
ARM_COMMON_DIRECT_STRD2_QU8,
ARM_COMMON_DIRECT_STRD1_DOT_QU8,
ARM_COMMON_DIRECT_STRD2_DOT_QU8,
#if MEGDNN_AARCH64
AARCH64_DIRECT_STRD2_FP16,
AARCH64_DIRECT_STRD2_FP32,
AARCH64_MATMUL_S8,
AARCH64_MATMUL_QU8,
#else
ARMV7_MATMUL_S8,
ARMV7_MATMUL_QU8,
#endif // MEGDNN_AARCH64
#endif
};

virtual ~AlgoBase() = default; virtual ~AlgoBase() = default;
virtual bool usable( virtual bool usable(
const NCBKernSizeParam& param, const NCBKernSizeParam& param,
@@ -255,12 +335,14 @@ public:


//! get the type of the algo //! get the type of the algo
virtual ConvAlgoTypePack get_algo_type() const = 0; virtual ConvAlgoTypePack get_algo_type() const = 0;
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
}; };


using AlgoMapper = AlgoBase::Mapper;
/** /**
* \brief get all the algorithm for the opr. * \brief get all the algorithm for the opr.
*/ */
virtual SmallVector<AlgoBase*> algo_pack();
virtual SmallVector<AlgoBase*> get_all_packed_algo();


/** /**
* \brief select algo according to input algo type * \brief select algo according to input algo type
@@ -305,6 +387,8 @@ private:


bool is_naive_algo(ConvBiasImpl::Algorithm* algo); bool is_naive_algo(ConvBiasImpl::Algorithm* algo);


Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const;

//! get algorithm set by user or by heuristic //! get algorithm set by user or by heuristic
Algorithm* get_algorithm( Algorithm* get_algorithm(
const NCBKernSizeParam& param, const NCBKernSizeParam& param,
@@ -320,6 +404,8 @@ private:
_megdnn_tensor_in bias, _megdnn_tensor_out dst, _megdnn_tensor_in bias, _megdnn_tensor_out dst,
_megdnn_workspace workspace, _megdnn_workspace workspace,
const PreprocessedFilter* preprocessed_filter); const PreprocessedFilter* preprocessed_filter);

static const AlgoPack& algo_pack();
}; };


inline bool is_enable_filter_preprocess( inline bool is_enable_filter_preprocess(


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save