
opr_impl.cpp 24 kB

/**
 * \file dnn/src/fallback/convolution/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/algos.h"
#include "src/fallback/convolution/run_conv.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"

#include "midout.h"

#include <cstring>

MIDOUT_DECL(megdnn_fb_conv_float)
MIDOUT_DECL(megdnn_fb_convbwd_float)

using namespace megdnn;
using namespace fallback;

namespace {
class NaiveConvolutionBackwardData final
        : public megdnn::ConvolutionBackwardData::Algorithm {
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "NCBD"; }
};
NaiveConvolutionBackwardData naive_conv_backward_data;
uint8_t fallback_deconv_algo_type_storage;
uint8_t fallback_conv_algo_type_storage;

template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
}  // namespace
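//! AlgoPack collects every forward algorithm available to the fallback
//! ConvolutionImpl: each ConvBias algorithm is wrapped in an AlgoDefault so
//! plain convolution can reuse it, and the generic fallback and naive
//! implementations are appended as last resorts.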
class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    SmallVector<std::unique_ptr<AlgoBase>> refhold;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->algo_pack();
        for (auto&& algorithm : conv_bias_algo) {
            // fallback algo
            refhold.emplace_back(new AlgoDefault(
                    static_cast<ConvBiasImpl*>(conv_bias_opr), algorithm));
            all_algos.emplace_back(refhold.back().get());
        }
        all_algos.emplace_back(&algo_fallback);
        all_algos.emplace_back(&algo_naive);
    }
    SmallVector<AlgoBase*> all_algos;
};
void* const ConvolutionImpl::sm_fallback_conv_algo_type =
        &fallback_conv_algo_type_storage;

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() {
    static AlgoPack sl_algo_pack;
    return sl_algo_pack.all_algos;
}

bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}

void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                           _megdnn_tensor_out dst,
                           const PreprocessedFilter* preprocessed_filter,
                           _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, workspace);
    ConvolutionImpl::Algorithm* algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        ncb_algo_get_workspace(algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec(src, filter, dst,
                                            preprocessed_filter, workspace);
    }
}

size_t ConvolutionImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst);
    Algorithm* algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
                src, filter, dst, preprocessed_filter);
    } else {
        return ncb_algo_get_workspace(algo, fparam);
    }
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst);
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
        return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter,
                                                                 dst);
    }
    return ret;
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst, size_t workspace_limit_in_bytes,
        bool reproducible) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, reproducible);
    if (result == nullptr) {
        result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
                src, filter, dst, workspace_limit_in_bytes, reproducible);
    }
    return result;
}

ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW44_DOT ||
        param().format == Param::Format::NCHW44) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NCHW ||
               param().format == Param::Format::NCHW_WINOGRAD) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NHWC) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d",
                      static_cast<int>(param().format));
    }
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    return {safe_u32(src[0]),
            {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
            {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
            check_layout_fwd(src, filter, dst),
            src.dtype,
            filter.dtype,
            dst.dtype,
            src.stride[0],
            dst.stride[0],
            {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
            {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
            param().compute_mode,
            nr_threads};
}

ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(src.layout, filter.layout, dst.layout);
    ret.src_ptr = src.raw_ptr;
    ret.filter_ptr = filter.raw_ptr;
    ret.dst_ptr = dst.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
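//! Each NCB kernel reported by the algorithm is submitted to the handle's
//! dispatcher; `run` recovers the multi-dimensional kernel index from the
//! flat dispatch index so the kernel can locate its slice of the work.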
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                         Algorithm* algo) {
    auto kerns = ncb_algo_dispatch_kern(algo, param);
    auto fallback_handle = handle();
    for (auto kernel : kerns) {
        megdnn_assert(param.filter_meta.format == Param::Format::NCHW ||
                              param.filter_meta.format == Param::Format::NHWC ||
                              param.filter_meta.format ==
                                      Param::Format::NCHW88 ||
                              param.filter_meta.format ==
                                      Param::Format::NCHW44,
                      "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        bool reproducible) {
    for (auto i : get_all_algorithms_with_ncb(param)) {
        if (static_cast<AlgoBase*>(i)->usable_reproducible(
                    this, param, AlgoSelectionStrategy::HEURISTIC,
                    reproducible) &&
            ncb_algo_get_workspace(i, param) <= workspace_limit_in_bytes) {
            return i;
        }
    }
    return nullptr;
}

std::vector<ConvolutionImpl::Algorithm*>
ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : algo_pack()) {
        if (i->usable(this, param, AlgoSelectionStrategy::FULL_RUN)) {
            if (i->is_preferred(this, param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    std::reverse(prefer_algos.begin(), prefer_algos.end());
    //! Prefer algo inserted from begin
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}
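//! get_algorithm caches the last heuristic choice: the cached algorithm is
//! reused as long as a byte-wise comparison shows the size parameters have
//! not changed since the previous call.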
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto set = execution_policy().algorithm) {
        return set;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo =
                get_algorithm_heuristic_with_ncb(param, workspace_size);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}

/* ===================== ConvolutionBackwardData ===================== */

void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type =
        &fallback_deconv_algo_type_storage;

struct ConvolutionBackwardDataImpl::AlgoPack {
    AlgoDirect direct;
    AlgoMatrixMul matmul;
};
ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;

void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
                                       _megdnn_tensor_in diff,
                                       _megdnn_tensor_out grad,
                                       _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad,
                                                        workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}

size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_workspace_with_ncb(fparam);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
                                                const TensorLayout& diff,
                                                const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    auto ret = get_all_algorithms_with_ncb(fparam);
    megdnn_assert(!ret.empty(), "no usable conv bwd algorithm");
    return ret;
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_limit_in_bytes,
        bool reproducible) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
                filter, diff, grad, workspace_limit_in_bytes, reproducible);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
                                            reproducible);
}

ConvolutionBackwardDataImpl::NCBKernSizeParam
ConvolutionBackwardDataImpl::make_ncb_kern_size_param(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC,
                      "invalid conv format");
        spatial_pos = 1;
    }
    // view the deconvolution as a forward convolution from grad to diff;
    // swap the dtypes so the forward layout check sees the src/dst dtypes
    // in the forward direction
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            0,
            0,
            0,
            param().compute_mode,
    };
}

ConvolutionBackwardDataImpl::NCBKernParam
ConvolutionBackwardDataImpl::make_ncb_kern_param(_megdnn_tensor_in filter,
                                                 _megdnn_tensor_in diff,
                                                 _megdnn_tensor_out grad,
                                                 _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);
    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(workspace.size >= required_workspace_in_bytes,
                  "required workspace: %zu; provided workspace: %zu",
                  required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.raw_ptr;
    ret.diff_ptr = diff.raw_ptr;
    ret.grad_ptr = grad.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
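//! Grouped backward data is executed as `group` single-group kernel calls:
//! the diff/filter/grad pointers advance by one group's stride in bytes after
//! each iteration, and the *_extra_mem_size fields tell the kernel how much
//! memory remains beyond the current group.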
void ConvolutionBackwardDataImpl::exec_with_ncb_kern(
        const NCBKernParam& param) {
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
    auto algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (algo == &naive_conv_backward_data || group == 1) {
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
        megdnn_assert(p1g.filter_meta.format == Param::Format::NCHW ||
                              p1g.filter_meta.format == Param::Format::NHWC,
                      "invalid conv format");
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
            if (p1g.filter_meta.format == Param::Format::NCHW) {
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
                incr_ptr(p1g.diff_ptr, istrd);
                incr_ptr(p1g.filter_ptr, fstrd);
                incr_ptr(p1g.grad_ptr, ostrd);
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}

size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_workspace(get_algorithm(p1g), p1g);
    }
    return ncb_1g_get_workspace(get_algorithm(param), param);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_all_algorithms(p1g);
    }
    return ncb_1g_get_all_algorithms(param);
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        bool reproducible) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
                                              reproducible);
    }
    return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
                                          reproducible);
}

size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->type() == sm_fallback_deconv_algo_type) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    megdnn_assert(algo == &naive_conv_backward_data);
    return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->type() == sm_fallback_deconv_algo_type) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }
    if (algo == &naive_conv_backward_data) {
#define cb(_dt)                                                        \
    do {                                                               \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) {     \
            MIDOUT_BEGIN(megdnn_fb_convbwd_float,                      \
                         midout_iv(DTypeTrait<_dt>::enumv)) {          \
                using ctype = DTypeTrait<_dt>::ctype;                  \
                return kern_naive<ctype, ctype, ctype>;                \
            }                                                          \
            MIDOUT_END();                                              \
        }                                                              \
    } while (0);
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
#define cb(dt_src, dt_dst)                                             \
    do {                                                               \
        if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv &&    \
            param.filter_type.enumv() == DTypeTrait<dt_src>::enumv &&  \
            param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) {    \
            return kern_naive<DTypeTrait<dt_src>::ctype,               \
                              DTypeTrait<dt_src>::ctype,               \
                              DTypeTrait<dt_dst>::ctype>;              \
        }                                                              \
    } while (0);
        cb(dtype::Int8, dtype::Int32)
        cb(dtype::Quantized8Asymm, dtype::QuantizedS32)
        cb(dtype::QuantizedS8, dtype::QuantizedS32)
        megdnn_throw("unsupported data type on ConvolutionBackwardData");
#undef cb
    }
    megdnn_throw(
            megdnn_mangle("no suitable ConvolutionBackwardData algorithm"));
}
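//! Heuristic: matrix multiplication is preferred for reasonably large
//! channel counts (OC * IC >= 32) or for 1x1 convolutions with unit stride
//! and no padding, where the computation degenerates into a plain GEMM.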
bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
        const NCBKernSizeParam& param) {
    auto&& fm = param.filter_meta;
    auto OC = fm.ocpg, IC = fm.icpg;
    return (OC * IC >= 32) ||
           (fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 &&
            fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
        const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    ret.reserve(2);
    ret.push_back(&naive_conv_backward_data);
    // insert from lowest to highest preference
    AlgoBase* cand[2] = {nullptr};
    if (param.filter_meta.group == 1 && param.filter_meta.dilation[0] == 1 &&
        param.filter_meta.dilation[1] == 1) {
        // we currently only have non-dilated algos
        if (param.filter_type.enumv() == DTypeEnum::Float32) {
            if (is_matrix_mul_preferred(param)) {
                cand[0] = &sm_algo_pack.direct;
                cand[1] = &sm_algo_pack.matmul;
            } else {
                cand[0] = &sm_algo_pack.matmul;
                cand[1] = &sm_algo_pack.direct;
            }
        } else {
            cand[0] = &sm_algo_pack.matmul;
        }
    }
    for (auto i : cand) {
        if (i && i->usable(this, param)) {
            ret.push_back(i);
        }
    }
    std::reverse(ret.begin(), ret.end());
    return ret;
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        bool reproducible) {
    for (auto i : ncb_1g_get_all_algorithms(param)) {
        if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
            if (reproducible) {
                if (i->is_reproducible()) {
                    return i;
                }
            } else {
                return i;
            }
        }
    }
    megdnn_assert(0,
                  "no suitable algorithm found within given workspace limit");
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
    if (auto set = execution_policy().algorithm) {
        return set;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max());
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}

// vim: syntax=cpp.doxygen

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose from. To run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit MegStudio.
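As a quick sanity check after installation, a script along these lines can confirm whether the bundled CUDA runtime actually sees a GPU. This is a minimal sketch against MegEngine's Python API; `megengine.is_cuda_available` and `megengine.functional.conv2d` are the intended entry points, though exact signatures may vary across versions.

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    # The single wheel covers both CPU and GPU; at runtime we can ask
    # whether a usable CUDA device is present.
    print("CUDA available:", mge.is_cuda_available())

    # A tiny convolution as a smoke test. With a GPU and driver installed it
    # runs on the GPU; otherwise MegEngine dispatches to CPU kernels such as
    # the fallback implementations shown above.
    x = mge.tensor(np.random.randn(1, 3, 8, 8).astype("float32"))
    w = mge.tensor(np.random.randn(4, 3, 3, 3).astype("float32"))
    y = F.conv2d(x, w, stride=1, padding=1)
    print(y.shape)  # (1, 4, 8, 8)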