
opr_impl.cpp

/**
 * \file dnn/src/arm_common/convolution/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./opr_impl.h"
#include "./int8x8x32/algos.h"
#include "./quint8/algos.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/common/opr_delegate.h"

using namespace megdnn;
using namespace arm_common;

namespace {
// The address of this storage serves as the unique tag identifying algorithms
// that belong to the arm_common backend.
uint8_t arm_common_algo_type_storage;
}  // anonymous namespace

/* ===================== ConvolutionBackwardData ===================== */

// Direct deconvolution algorithms based on the ARMv8.2 dot-product instructions;
// they are only compiled in when the target supports that extension.
struct ConvolutionBackwardDataImpl::AlgoPack {
#if __ARM_FEATURE_DOTPROD
    AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot;
    AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot;
    AlgoUdot8DirectStride1 quint8_direct_stride1_udot;
    AlgoUdot8DirectStride2 quint8_direct_stride2_udot;
#endif
};

ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;

void* const ConvolutionBackwardDataImpl::sm_arm_common_algo_type =
        &arm_common_algo_type_storage;

// Dispatch to an arm_common kernel when the algorithm belongs to this backend;
// otherwise defer to the fallback implementation.
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
        Algorithm* algo, const NCBKernSizeParam& param) {
    if (algo->type() == sm_arm_common_algo_type) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }
    return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo, param);
}

// Workspace queries follow the same ownership check as kernel dispatch.
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    if (algo->type() == sm_arm_common_algo_type) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo, param);
}

// Collect the fallback algorithms, then prepend the dot-product algorithms that
// are usable for the given filter/grad dtypes so they are preferred.
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(const NCBKernSizeParam& param) {
    auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(param);
#if __ARM_FEATURE_DOTPROD
    if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
         param.filter_type.enumv() == DTypeEnum::Int8) &&
        (param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
         param.grad_type.enumv() == DTypeEnum::Int32)) {
        if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) {
            ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot);
        }
        if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) {
            ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot);
        }
    } else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm &&
               param.grad_type.enumv() == DTypeEnum::QuantizedS32) {
        if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) {
            ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot);
        }
        if (sm_algo_pack.quint8_direct_stride2_udot.usable(this, param)) {
            ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride2_udot);
        }
    }
#endif
    return ret;
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // arm common version 0
    return "DeconvAC0";
}

// vim: syntax=cpp.doxygen

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
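Since the same wheel works with or without a GPU, the only run-time decision is where to place tensors. Below is a minimal Python sketch of that check; it assumes the standard MegEngine Python API names is_cuda_available, set_default_device, and the "gpu0"/"cpu0" device naming convention.

import megengine as mge
import megengine.functional as F

# The bundled CUDA runtime is used only if a usable GPU and driver are present.
if mge.is_cuda_available():
    mge.set_default_device("gpu0")   # place subsequent tensors/ops on the first GPU
else:
    mge.set_default_device("cpu0")   # otherwise fall back to CPU; the code is unchanged

x = mge.tensor([[1.0, 2.0], [3.0, 4.0]])
print(F.matmul(x, x), x.device)      # same script runs on both CPU-only and GPU machines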