Browse Source

fix(mgb/dnn): fix cuda naive matmul algo

GitOrigin-RevId: 79c9bba73b
tags/v1.4.0-rc1
Megvii Engine Team 5 years ago
parent
commit
ef9aa80074
2 changed files with 2 additions and 2 deletions
  1. +1
    -1
      dnn/src/cuda/matrix_mul/algos.cpp
  2. +1
    -1
      dnn/src/cuda/matrix_mul/algos.h

+ 1
- 1
dnn/src/cuda/matrix_mul/algos.cpp View File

@@ -29,7 +29,6 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
#if CUDA_VERSION >= 10010 #if CUDA_VERSION >= 10010
all_algos.push_back(&cublas_lt); all_algos.push_back(&cublas_lt);
#endif #endif
all_algos.push_back(&naive);
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
all_algos.push_back(&bfloat16); all_algos.push_back(&bfloat16);
#endif #endif
@@ -45,6 +44,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&algo); all_algos.push_back(&algo);
} }
#endif #endif
all_algos.push_back(&naive);


for (auto&& algo : all_algos) { for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo); m_all_algos_map.emplace(algo->info().desc, algo);


+ 1
- 1
dnn/src/cuda/matrix_mul/algos.h View File

@@ -157,7 +157,7 @@ public:
void exec(const ExecArgs& args) const override; void exec(const ExecArgs& args) const override;
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE)
AlgoAttribute attribute() const override { AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
} }
}; };




Loading…
Cancel
Save