You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

opr_impl.cpp 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. /**
  2. * \file dnn/src/fallback/convolution/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/fallback/convolution/opr_impl.h"
  13. #include "src/common/algo_chooser.h"
  14. #include "src/common/metahelper.h"
  15. #include "src/common/opr_delegate.h"
  16. #include "src/common/utils.h"
  17. #include "src/fallback/convolution/algos.h"
  18. #include "src/fallback/convolution/run_conv.h"
  19. #include "src/naive/convolution/helper.h"
  20. #include "src/naive/handle.h"
  21. #include "midout.h"
  22. #include <cstring>
  23. MIDOUT_DECL(megdnn_fb_conv_float)
  24. MIDOUT_DECL(megdnn_fb_convbwd_float)
  25. using namespace megdnn;
  26. using namespace fallback;
namespace {
// Last-resort backward-data algorithm that delegates to the naive kernels.
// It is identified elsewhere in this file by pointer equality against
// naive_conv_backward_data.
class NaiveConvolutionBackwardData final
        : public megdnn::ConvolutionBackwardData::Algorithm {
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "NCBD"; }
};
NaiveConvolutionBackwardData naive_conv_backward_data;
// The *addresses* of these bytes act as unique algorithm-type tags; see
// sm_fallback_deconv_algo_type / sm_fallback_conv_algo_type below.
uint8_t fallback_deconv_algo_type_storage;
uint8_t fallback_conv_algo_type_storage;
// Advance \p dst by \p delta bytes (not elements), preserving its type.
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
}  // namespace
// Collects every forward-convolution algorithm available to the fallback
// implementation: each ConvBias algorithm is wrapped in an AlgoDefault, then
// the plain fallback and naive algorithms are appended last (lowest
// priority). Registration order is significant for algorithm selection.
class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    // Owns the AlgoDefault wrappers referenced from all_algos.
    SmallVector<std::unique_ptr<AlgoBase>> refhold;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->algo_pack();
        for (auto&& algorithm : conv_bias_algo) {
            // fallback algo
            refhold.emplace_back(new AlgoDefault(
                    static_cast<ConvBiasImpl*>(conv_bias_opr), algorithm));
            all_algos.emplace_back(refhold.back().get());
        }
        all_algos.emplace_back(&algo_fallback);
        all_algos.emplace_back(&algo_naive);
    }
    SmallVector<AlgoBase*> all_algos;
};
// Unique type tag distinguishing fallback conv algorithms from others; its
// value is just the address of a file-local byte.
void* const ConvolutionImpl::sm_fallback_conv_algo_type =
        &fallback_conv_algo_type_storage;

// Returns the lazily-constructed, process-wide list of forward algorithms.
SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() {
    static AlgoPack sl_algo_pack;
    return sl_algo_pack.all_algos;
}
  68. bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
  69. return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
  70. }
// Runs the forward convolution: uses the selected fallback algorithm when it
// is non-naive and fits in the caller-provided workspace, otherwise delegates
// to the naive implementation.
void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                           _megdnn_tensor_out dst,
                           _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, workspace);
    ConvolutionImpl::Algorithm* algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        ncb_algo_get_workspace(algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        // Either no fallback algorithm applies or the workspace is too small.
        naive::ConvolutionForwardImpl::exec(src, filter, dst, workspace);
    }
}
  83. size_t ConvolutionImpl::get_workspace_in_bytes(const TensorLayout& src,
  84. const TensorLayout& filter,
  85. const TensorLayout& dst) {
  86. auto fparam = make_ncb_kern_size_param(src, filter, dst);
  87. Algorithm* algo = get_algorithm(fparam);
  88. if (is_naive_algo(algo)) {
  89. return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
  90. src, filter, dst);
  91. } else {
  92. return ncb_algo_get_workspace(algo, fparam);
  93. }
  94. }
  95. std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
  96. const TensorLayout& src, const TensorLayout& filter,
  97. const TensorLayout& dst) {
  98. auto fparam = make_ncb_kern_size_param(src, filter, dst);
  99. auto ret = get_all_algorithms_with_ncb(fparam);
  100. if (ret.empty()) {
  101. return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter,
  102. dst);
  103. }
  104. return ret;
  105. }
  106. ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
  107. const TensorLayout& src, const TensorLayout& filter,
  108. const TensorLayout& dst, size_t workspace_limit_in_bytes,
  109. bool reproducible) {
  110. auto fparam = make_ncb_kern_size_param(src, filter, dst);
  111. auto result = get_algorithm_heuristic_with_ncb(
  112. fparam, workspace_limit_in_bytes, reproducible);
  113. if (result == nullptr) {
  114. result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
  115. src, filter, dst, workspace_limit_in_bytes, reproducible);
  116. }
  117. return result;
  118. }
  119. ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
  120. const TensorLayout& src, const TensorLayout& filter,
  121. const TensorLayout& dst) {
  122. auto safe_u32 = [](size_t v) -> uint32_t {
  123. megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
  124. "value too large: %zu", v);
  125. return v;
  126. };
  127. size_t spatial_pos;
  128. if (param().format == Param::Format::NCHW88 ||
  129. param().format == Param::Format::NCHW8 ||
  130. param().format == Param::Format::NCHW4 ||
  131. param().format == Param::Format::NCHW44) {
  132. spatial_pos = 2;
  133. } else if (param().format == Param::Format::NCHW ||
  134. param().format == Param::Format::NCHW_WINOGRAD) {
  135. spatial_pos = 2;
  136. } else if (param().format == Param::Format::NHWC) {
  137. spatial_pos = 1;
  138. } else {
  139. megdnn_assert(0, "invalid conv format %d",
  140. static_cast<int>(param().format));
  141. }
  142. size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
  143. ->megcore_dispatcher()
  144. ->nr_threads();
  145. return {safe_u32(src[0]),
  146. {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
  147. {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
  148. check_layout_fwd(src, filter, dst),
  149. src.dtype,
  150. filter.dtype,
  151. dst.dtype,
  152. src.stride[0],
  153. dst.stride[0],
  154. {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
  155. {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
  156. param().compute_mode,
  157. nr_threads};
  158. }
// Builds the full kernel parameter: the size part is computed from the
// layouts, then the raw data / workspace pointers are attached.
ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
    NCBKernParam ret;
    // NCBKernParam derives from NCBKernSizeParam; fill the size slice first.
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(src.layout, filter.layout, dst.layout);
    ret.src_ptr = src.raw_ptr;
    ret.filter_ptr = filter.raw_ptr;
    ret.dst_ptr = dst.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
// Dispatches every kernel produced by \p algo on the handle's dispatcher.
// Each kernel is invoked global_size.total_size() times; the dispatcher
// supplies (index, thread_id), which are translated into a CpuNDRange id.
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                         Algorithm* algo) {
    auto kerns = ncb_algo_dispatch_kern(algo, param);
    auto fallback_handle = handle();
    for (auto kernel : kerns) {
        megdnn_assert(param.filter_meta.format == Param::Format::NCHW ||
                              param.filter_meta.format == Param::Format::NHWC ||
                              param.filter_meta.format == Param::Format::NCHW88,
                      "invalid conv format");
        // Capture by copy: the lambda may run after this frame returns.
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}
  189. ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
  190. const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
  191. bool reproducible) {
  192. for (auto i : get_all_algorithms_with_ncb(param)) {
  193. if (static_cast<AlgoBase*>(i)->usable_reproducible(
  194. this, param, AlgoSelectionStrategy::HEURISTIC,
  195. reproducible) &&
  196. ncb_algo_get_workspace(i, param) <= workspace_limit_in_bytes) {
  197. return i;
  198. }
  199. }
  200. return nullptr;
  201. }
  202. std::vector<ConvolutionImpl::Algorithm*>
  203. ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
  204. std::vector<Algorithm*> ret;
  205. std::vector<Algorithm*> prefer_algos;
  206. for (auto&& i : algo_pack()) {
  207. if (i->usable(this, param, AlgoSelectionStrategy::FULL_RUN)) {
  208. if (i->is_preferred(this, param)) {
  209. prefer_algos.push_back(i);
  210. } else {
  211. ret.push_back(i);
  212. }
  213. }
  214. }
  215. std::reverse(prefer_algos.begin(), prefer_algos.end());
  216. //! Prefer algo inserted from begin
  217. ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
  218. return ret;
  219. }
// Returns the algorithm to run: an explicitly-set execution policy wins;
// otherwise the last heuristic choice is cached and reused while the size
// param stays byte-identical.
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto set = execution_policy().algorithm) {
        return set;
    }
    // NOTE(review): byte-wise memcmp assumes padding bytes inside
    // NCBKernSizeParam are deterministic for identical params -- confirm.
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo =
                get_algorithm_heuristic_with_ncb(param, workspace_size);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}
// Identifier keying cached algorithm selections for this implementation.
const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}
/* ===================== ConvolutionBackwardData ===================== */
// Unique type tag distinguishing fallback deconv algorithms; its value is
// the address of a file-local byte.
void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type =
        &fallback_deconv_algo_type_storage;
// Static pack holding the two single-group backward-data algorithms.
struct ConvolutionBackwardDataImpl::AlgoPack {
    AlgoDirect direct;
    AlgoMatrixMul matmul;
};
ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;
  245. void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
  246. _megdnn_tensor_in diff,
  247. _megdnn_tensor_out grad,
  248. _megdnn_workspace workspace) {
  249. if (param().format == param::Convolution::Format::NHWCD4 ||
  250. param().format == param::Convolution::Format::NCHW4) {
  251. return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad,
  252. workspace);
  253. }
  254. auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
  255. return exec_with_ncb_kern(fparam);
  256. }
  257. size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
  258. const TensorLayout& filter, const TensorLayout& diff,
  259. const TensorLayout& grad) {
  260. if (param().format == param::Convolution::Format::NHWCD4 ||
  261. param().format == param::Convolution::Format::NCHW4) {
  262. return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
  263. filter, diff, grad);
  264. }
  265. auto fparam = make_ncb_kern_size_param(filter, diff, grad);
  266. return get_workspace_with_ncb(fparam);
  267. }
  268. std::vector<ConvolutionBackwardDataImpl::Algorithm*>
  269. ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
  270. const TensorLayout& diff,
  271. const TensorLayout& grad) {
  272. if (param().format == param::Convolution::Format::NHWCD4 ||
  273. param().format == param::Convolution::Format::NCHW4) {
  274. return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
  275. filter, diff, grad);
  276. }
  277. auto fparam = make_ncb_kern_size_param(filter, diff, grad);
  278. auto ret = get_all_algorithms_with_ncb(fparam);
  279. megdnn_assert(!ret.empty(), "no usable conv fwd algorithm");
  280. return ret;
  281. }
  282. ConvolutionBackwardDataImpl::Algorithm*
  283. ConvolutionBackwardDataImpl::get_algorithm_heuristic(
  284. const TensorLayout& filter, const TensorLayout& diff,
  285. const TensorLayout& grad, size_t workspace_limit_in_bytes,
  286. bool reproducible) {
  287. if (param().format == param::Convolution::Format::NHWCD4 ||
  288. param().format == param::Convolution::Format::NCHW4) {
  289. return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
  290. filter, diff, grad, workspace_limit_in_bytes, reproducible);
  291. }
  292. auto fparam = make_ncb_kern_size_param(filter, diff, grad);
  293. return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
  294. reproducible);
  295. }
// Builds the size-only kernel parameter for backward-data.
// NOTE: the brace-init list must match the NCBKernSizeParam aggregate order.
ConvolutionBackwardDataImpl::NCBKernSizeParam
ConvolutionBackwardDataImpl::make_ncb_kern_size_param(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    // Index of the first spatial dim (H): 2 for NCHW, 1 for NHWC.
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC,
                      "invalid conv format");
        spatial_pos = 1;
    }
    // Backward-data is validated as the equivalent forward conv with grad as
    // input and diff as output; the dtypes are swapped to match that view.
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            // extra mem sizes start at zero; exec_with_ncb_kern fills them
            // in for the multi-group path
            0,
            0,
            0,
            param().compute_mode,
    };
}
// Builds the full backward-data kernel parameter and validates that the
// caller-provided workspace is large enough before attaching pointers.
ConvolutionBackwardDataImpl::NCBKernParam
ConvolutionBackwardDataImpl::make_ncb_kern_param(_megdnn_tensor_in filter,
                                                 _megdnn_tensor_in diff,
                                                 _megdnn_tensor_out grad,
                                                 _megdnn_workspace workspace) {
    NCBKernParam ret;
    // NCBKernParam derives from NCBKernSizeParam; fill the size slice first.
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);
    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(workspace.size >= required_workspace_in_bytes,
                  "required workspace: %zu; provided workspace: %zu",
                  required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.raw_ptr;
    ret.diff_ptr = diff.raw_ptr;
    ret.grad_ptr = grad.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
// Executes backward-data by reducing the multi-group case to repeated
// single-group kernel invocations with per-group pointer offsets.
void ConvolutionBackwardDataImpl::exec_with_ncb_kern(
        const NCBKernParam& param) {
    // p1g: the same problem viewed as a single group.
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
    auto algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (algo == &naive_conv_backward_data || group == 1) {
        // The naive algorithm handles grouping itself; a single group needs
        // no pointer bookkeeping either.
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
        megdnn_assert(p1g.filter_meta.format == Param::Format::NCHW ||
                              p1g.filter_meta.format == Param::Format::NHWC,
                      "invalid conv format");
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            // Per-group byte strides for diff (input), filter, grad (output).
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            // extra_mem_size: bytes still addressable beyond the current
            // group's tensor; shrinks by one group stride per iteration.
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
            if (p1g.filter_meta.format == Param::Format::NCHW) {
                // NCHW: the per-group channel stride also spans the spatial
                // extent.
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
                incr_ptr(p1g.diff_ptr, istrd);
                incr_ptr(p1g.filter_ptr, fstrd);
                incr_ptr(p1g.grad_ptr, ostrd);
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}
  406. size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
  407. const NCBKernSizeParam& param) {
  408. if (param.filter_meta.group != 1) {
  409. auto p1g = param;
  410. p1g.filter_meta.group = 1;
  411. return ncb_1g_get_workspace(get_algorithm(p1g), p1g);
  412. }
  413. return ncb_1g_get_workspace(get_algorithm(param), param);
  414. }
  415. std::vector<ConvolutionBackwardDataImpl::Algorithm*>
  416. ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
  417. const NCBKernSizeParam& param) {
  418. if (param.filter_meta.group != 1) {
  419. auto p1g = param;
  420. p1g.filter_meta.group = 1;
  421. return ncb_1g_get_all_algorithms(p1g);
  422. }
  423. return ncb_1g_get_all_algorithms(param);
  424. }
  425. ConvolutionBackwardDataImpl::Algorithm*
  426. ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
  427. const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
  428. bool reproducible) {
  429. if (param.filter_meta.group != 1) {
  430. auto p1g = param;
  431. p1g.filter_meta.group = 1;
  432. return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
  433. reproducible);
  434. }
  435. return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
  436. reproducible);
  437. }
  438. size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
  439. Algorithm* algo, const NCBKernSizeParam& param) {
  440. megdnn_assert(param.filter_meta.group == 1);
  441. if (algo->type() == sm_fallback_deconv_algo_type) {
  442. return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
  443. }
  444. megdnn_assert(algo == &naive_conv_backward_data);
  445. return 0;
  446. }
// Returns the kernel for a single-group backward-data problem. Fallback
// algorithms provide their own kernel; the naive algorithm selects a typed
// kern_naive instantiation from the dtype combination, throwing if no
// supported combination matches.
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->type() == sm_fallback_deconv_algo_type) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }
    if (algo == &naive_conv_backward_data) {
        // Float kernels: same ctype for diff/filter/grad.
#define cb(_dt)                                                        \
    do {                                                               \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) {     \
            MIDOUT_BEGIN(megdnn_fb_convbwd_float,                      \
                         midout_iv(DTypeTrait<_dt>::enumv)) {          \
                using ctype = DTypeTrait<_dt>::ctype;                  \
                return kern_naive<ctype, ctype, ctype>;                \
            }                                                          \
            MIDOUT_END();                                              \
        }                                                              \
    } while (0);
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
        // Integer / quantized kernels: (diff, filter) dtype -> grad dtype.
#define cb(dt_src, dt_dst)                                             \
    do {                                                               \
        if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv &&    \
            param.filter_type.enumv() == DTypeTrait<dt_src>::enumv &&  \
            param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) {    \
            return kern_naive<DTypeTrait<dt_src>::ctype,               \
                              DTypeTrait<dt_src>::ctype,               \
                              DTypeTrait<dt_dst>::ctype>;              \
        }                                                              \
    } while (0);
        cb(dtype::Int8, dtype::Int32) cb(dtype::Quantized8Asymm,
                                         dtype::QuantizedS32)
                cb(dtype::QuantizedS8, dtype::QuantizedS32) megdnn_throw(
                        "unsupported data type on ConvolutionBackwardData");
#undef cb
    }
    megdnn_throw(
            megdnn_mangle("no suitable ConvolutionBackwardData algorithm"));
}
  487. bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
  488. const NCBKernSizeParam& param) {
  489. auto&& fm = param.filter_meta;
  490. auto OC = fm.ocpg, IC = fm.icpg;
  491. return (OC * IC >= 32) ||
  492. (fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 &&
  493. fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1);
  494. }
// Lists candidate single-group backward-data algorithms, most preferred
// first. The naive algorithm is always included as the last resort.
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
        const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    ret.reserve(2);
    ret.push_back(&naive_conv_backward_data);
    // insert from lowest to highest preference
    AlgoBase* cand[2] = {nullptr};
    if (param.filter_meta.group == 1 && param.filter_meta.dilation[0] == 1 &&
        param.filter_meta.dilation[1] == 1) {
        // we currently only have non-dilated algos
        if (param.filter_type.enumv() == DTypeEnum::Float32) {
            if (is_matrix_mul_preferred(param)) {
                cand[0] = &sm_algo_pack.direct;
                cand[1] = &sm_algo_pack.matmul;
            } else {
                cand[0] = &sm_algo_pack.matmul;
                cand[1] = &sm_algo_pack.direct;
            }
        } else {
            // non-float dtypes: only the matmul algorithm is offered
            cand[0] = &sm_algo_pack.matmul;
        }
    }
    for (auto i : cand) {
        if (i && i->usable(this, param)) {
            ret.push_back(i);
        }
    }
    // the list was built lowest-preference-first; reverse so callers see the
    // most preferred algorithm at the front
    std::reverse(ret.begin(), ret.end());
    return ret;
}
  526. ConvolutionBackwardDataImpl::Algorithm*
  527. ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
  528. const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
  529. bool reproducible) {
  530. for (auto i : ncb_1g_get_all_algorithms(param)) {
  531. if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
  532. if (reproducible) {
  533. if (i->is_reproducible()) {
  534. return i;
  535. }
  536. } else {
  537. return i;
  538. }
  539. }
  540. }
  541. megdnn_assert(0,
  542. "no suitable algorithm found within given workspace limit");
  543. }
// Returns the algorithm to run for a single-group problem: an explicitly-set
// execution policy wins; otherwise the last heuristic choice is cached and
// reused while the size param stays byte-identical.
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
    if (auto set = execution_policy().algorithm) {
        return set;
    }
    // NOTE(review): byte-wise memcmp assumes padding bytes inside
    // NCBKernSizeParam are deterministic for identical params -- confirm.
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max());
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}
// Identifier keying cached algorithm selections for this implementation.
const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}
  561. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)