
opr_impl.cpp 28 kB

/**
 * \file dnn/src/fallback/convolution/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/algos.h"
#include "src/fallback/convolution/run_conv.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"
#include "midout.h"

#include <cstring>

MIDOUT_DECL(megdnn_fb_conv_float)
MIDOUT_DECL(megdnn_fb_convbwd_float)

using namespace megdnn;
using namespace fallback;

namespace {
class NaiveConvolutionBackwardData final
        : public megdnn::ConvolutionBackwardData::Algorithm {
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "NCBD"; }
};
NaiveConvolutionBackwardData naive_conv_backward_data;
uint8_t fallback_deconv_algo_type_storage;
uint8_t fallback_conv_algo_type_storage;

template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
}  // namespace
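// incr_ptr advances a typed pointer by a raw byte delta; the grouped
// backward-data path below relies on it to step through per-group
// filter/diff/grad slabs whose strides are computed in bytes.
// A minimal usage sketch:
//
//     float buf[8];
//     float* p = buf;
//     incr_ptr(p, 4 * sizeof(float));  // p now points at buf + 4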
class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    SmallVector<std::unique_ptr<AlgoBase>> refhold;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->algo_pack();
        for (auto&& algorithm : conv_bias_algo) {
            // fallback algo
            refhold.emplace_back(new AlgoDefault(algorithm));
            all_algos.emplace_back(refhold.back().get());
        }
        all_algos.emplace_back(&algo_fallback);
        all_algos.emplace_back(&algo_naive);
    }
    SmallVector<AlgoBase*> all_algos;
};

void* const ConvolutionImpl::sm_fallback_conv_algo_type =
        &fallback_conv_algo_type_storage;

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() {
    static AlgoPack sl_algo_pack;
    return sl_algo_pack.all_algos;
}

bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}

#define NCB_ALGO_FUNC(name, algo, param) \
    static_cast<AlgoBase*>(algo)->name(param)
void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                           _megdnn_tensor_out dst,
                           const PreprocessedFilter* preprocessed_filter,
                           _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter,
                                      workspace);
    ConvolutionImpl::Algorithm* algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec(src, filter, dst,
                                            preprocessed_filter, workspace);
    }
}

void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout,
                                      _megdnn_tensor_in filter,
                                      const TensorLayout& dst_layout,
                                      PreprocessedFilter* preprocessed_filter,
                                      _megdnn_workspace workspace) {
    //! exec_preprocess currently only supports preprocessing the weights
    //! before exec; src/dst are ignored, so they are just set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter,
                                      workspace);
    //! no workspace_size limit should be passed here, otherwise a matching
    //! algorithm may not be found
    ConvolutionImpl::Algorithm* algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) && NCB_ALGO_FUNC(get_preprocess_workspace, algo,
                                              fparam) <= workspace.size) {
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec_preprocess(
                src_layout, filter, dst_layout, preprocessed_filter,
                workspace);
    }
}

size_t ConvolutionImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    auto fparam =
            make_ncb_kern_size_param(src, filter, dst, preprocessed_filter);
    Algorithm* algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
                src, filter, dst, preprocessed_filter);
    } else {
        return NCB_ALGO_FUNC(get_workspace, algo, fparam);
    }
}

size_t ConvolutionImpl::get_preprocess_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    Algorithm* algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::
                get_preprocess_workspace_in_bytes(src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
    }
}
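// Note the shared pattern above: each public entry point builds an NCB size
// param, asks get_algorithm() for a candidate, and falls back to the naive
// ConvolutionForwardImpl whenever the naive/DEFAULT algorithm is selected
// (or, in exec/exec_preprocess, when the provided workspace is too small).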
SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    Algorithm* algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::
                deduce_preprocessed_filter_layout(src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, algo, fparam);
    }
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
        return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter,
                                                                 dst);
    }
    return ret;
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst, size_t workspace_limit_in_bytes,
        bool reproducible) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, reproducible);
    if (result == nullptr) {
        result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
                src, filter, dst, workspace_limit_in_bytes, reproducible);
    }
    return result;
}
ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW44_DOT ||
        param().format == Param::Format::NCHW44) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NCHW ||
               param().format == Param::Format::NCHW_WINOGRAD) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NHWC) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d",
                      static_cast<int>(param().format));
    }
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    return {safe_u32(src[0]),
            {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
            {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
            check_layout_fwd(src, filter, dst),
            src.dtype,
            filter.dtype,
            dst.dtype,
            src.stride[0],
            dst.stride[0],
            {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
            {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
            param().compute_mode,
            nr_threads,
            preprocessed_filter};
}

ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter,
        _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) = make_ncb_kern_size_param(
            src.layout, filter.layout, dst.layout, preprocessed_filter);
    ret.src_ptr = src.raw_ptr;
    ret.filter_ptr = filter.raw_ptr;
    ret.dst_ptr = dst.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
                                                    Algorithm* algo) {
    auto kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
    auto fallback_handle = handle();
    for (auto kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44,
                "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                         Algorithm* algo) {
    auto kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
    auto fallback_handle = handle();
    for (auto kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44,
                "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}
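// Each NCB kernel above is dispatched over kernel.global_size.total_size()
// work items; the lambda decodes the flat index back into a CpuNDRange so
// the kernel sees its multi-dimensional position along with the thread id.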
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        bool reproducible) {
    for (auto i : get_all_algorithms_with_ncb(param)) {
        bool usable_reproducible =
                static_cast<AlgoBase*>(i)->usable_reproducible(
                        param, AlgoSelectionStrategy::HEURISTIC,
                        reproducible);
        if (usable_reproducible && NCB_ALGO_FUNC(get_workspace, i, param) <=
                                           workspace_limit_in_bytes) {
            return i;
        }
    }
    return nullptr;
}

std::vector<ConvolutionImpl::Algorithm*>
ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : algo_pack()) {
        if (i->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    std::reverse(prefer_algos.begin(), prefer_algos.end());
    //! preferred algos are inserted at the beginning of the list
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto set = execution_policy().algorithm) {
        return set;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param,
               sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo =
                get_algorithm_heuristic_with_ncb(param, workspace_size);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}
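// The heuristic result is cached: a raw memcmp over the NCBKernSizeParam
// bytes decides whether the previous selection is still valid, so repeated
// execs with identical shapes, dtypes and strides skip the algorithm search.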
const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}
/* ===================== ConvolutionBackwardData ===================== */

void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type =
        &fallback_deconv_algo_type_storage;

struct ConvolutionBackwardDataImpl::AlgoPack {
    AlgoDirect direct;
    AlgoMatrixMul matmul;
};
ConvolutionBackwardDataImpl::AlgoPack
        ConvolutionBackwardDataImpl::sm_algo_pack;

void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
                                       _megdnn_tensor_in diff,
                                       _megdnn_tensor_out grad,
                                       _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad,
                                                        workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}

size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_workspace_with_ncb(fparam);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
                                                const TensorLayout& diff,
                                                const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    auto ret = get_all_algorithms_with_ncb(fparam);
    megdnn_assert(!ret.empty(), "no usable conv bwd_data algorithm");
    return ret;
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_limit_in_bytes,
        bool reproducible) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
                filter, diff, grad, workspace_limit_in_bytes, reproducible);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
                                            reproducible);
}

ConvolutionBackwardDataImpl::NCBKernSizeParam
ConvolutionBackwardDataImpl::make_ncb_kern_size_param(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC,
                      "invalid conv format");
        spatial_pos = 1;
    }
    // check_layout_fwd validates these layouts as a forward convolution
    // (grad plays the role of src, diff the role of dst), so the dtypes
    // are swapped to match the forward direction
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            0,
            0,
            0,
            param().compute_mode,
    };
}

ConvolutionBackwardDataImpl::NCBKernParam
ConvolutionBackwardDataImpl::make_ncb_kern_param(_megdnn_tensor_in filter,
                                                 _megdnn_tensor_in diff,
                                                 _megdnn_tensor_out grad,
                                                 _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);
    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(workspace.size >= required_workspace_in_bytes,
                  "required workspace: %zu; provided workspace: %zu",
                  required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.raw_ptr;
    ret.diff_ptr = diff.raw_ptr;
    ret.grad_ptr = grad.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
void ConvolutionBackwardDataImpl::exec_with_ncb_kern(
        const NCBKernParam& param) {
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
    auto algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (algo == &naive_conv_backward_data || group == 1) {
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
        megdnn_assert(p1g.filter_meta.format == Param::Format::NCHW ||
                              p1g.filter_meta.format == Param::Format::NHWC,
                      "invalid conv format");
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg *
                    p1g.filter_meta.ocpg * p1g.filter_meta.spatial[0] *
                    p1g.filter_meta.spatial[1] * p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
            if (p1g.filter_meta.format == Param::Format::NCHW) {
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
                incr_ptr(p1g.diff_ptr, istrd);
                incr_ptr(p1g.filter_ptr, fstrd);
                incr_ptr(p1g.grad_ptr, ostrd);
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}
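// A worked example of the per-group strides above, assuming NCHW float32,
// group = 2, icpg = 8, ocpg = 16 and a 3x3 filter:
//     fstrd = 8 * 16 * 3 * 3 * 4 = 4608 bytes per group,
//     istrd = 16 * IH * IW * 4 bytes  (IH, IW = diff spatial size),
//     ostrd =  8 * OH * OW * 4 bytes  (OH, OW = grad spatial size),
// so each loop iteration advances filter/diff/grad to the next group's slab.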
size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_workspace(get_algorithm(p1g), p1g);
    }
    return ncb_1g_get_workspace(get_algorithm(param), param);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_all_algorithms(p1g);
    }
    return ncb_1g_get_all_algorithms(param);
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        bool reproducible) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
                                              reproducible);
    }
    return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
                                          reproducible);
}
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->type() == sm_fallback_deconv_algo_type) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    megdnn_assert(algo == &naive_conv_backward_data);
    return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->type() == sm_fallback_deconv_algo_type) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }
    if (algo == &naive_conv_backward_data) {
#define cb(_dt)                                                    \
    do {                                                           \
        if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
            MIDOUT_BEGIN(megdnn_fb_convbwd_float,                  \
                         midout_iv(DTypeTrait<_dt>::enumv)) {      \
                using ctype = DTypeTrait<_dt>::ctype;              \
                return kern_naive<ctype, ctype, ctype>;            \
            }                                                      \
            MIDOUT_END();                                          \
        }                                                          \
    } while (0);
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
#define cb(dt_src, dt_dst)                                            \
    do {                                                              \
        if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv &&   \
            param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \
            param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) {   \
            return kern_naive<DTypeTrait<dt_src>::ctype,              \
                              DTypeTrait<dt_src>::ctype,              \
                              DTypeTrait<dt_dst>::ctype>;             \
        }                                                             \
    } while (0);
        cb(dtype::Int8, dtype::Int32);
        cb(dtype::Quantized8Asymm, dtype::QuantizedS32);
        cb(dtype::QuantizedS8, dtype::QuantizedS32);
        megdnn_throw("unsupported data type on ConvolutionBackwardData");
#undef cb
    }
    megdnn_throw(
            megdnn_mangle("no suitable ConvolutionBackwardData algorithm"));
}
bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
        const NCBKernSizeParam& param) {
    auto&& fm = param.filter_meta;
    auto OC = fm.ocpg, IC = fm.icpg;
    return (OC * IC >= 32) ||
           (fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 &&
            fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
        const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    ret.reserve(2);
    ret.push_back(&naive_conv_backward_data);
    // insert from lowest to highest preference
    AlgoBase* cand[2] = {nullptr};
    if (param.filter_meta.group == 1 && param.filter_meta.dilation[0] == 1 &&
        param.filter_meta.dilation[1] == 1) {
        // we currently only have non-dilated algos
        if (param.filter_type.enumv() == DTypeEnum::Float32) {
            if (is_matrix_mul_preferred(param)) {
                cand[0] = &sm_algo_pack.direct;
                cand[1] = &sm_algo_pack.matmul;
            } else {
                cand[0] = &sm_algo_pack.matmul;
                cand[1] = &sm_algo_pack.direct;
            }
        } else {
            cand[0] = &sm_algo_pack.matmul;
        }
    }
    for (auto i : cand) {
        if (i && i->usable(this, param)) {
            ret.push_back(i);
        }
    }
    std::reverse(ret.begin(), ret.end());
    return ret;
}
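// ncb_1g_get_all_algorithms builds the candidate list from lowest to highest
// preference and reverses it at the end, so callers such as the heuristic
// below can simply take the first usable entry within the workspace limit.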
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        bool reproducible) {
    for (auto i : ncb_1g_get_all_algorithms(param)) {
        if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
            if (reproducible) {
                if (i->is_reproducible()) {
                    return i;
                }
            } else {
                return i;
            }
        }
    }
    megdnn_assert(0,
                  "no suitable algorithm found within given workspace limit");
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
    if (auto set = execution_policy().algorithm) {
        return set;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param,
               sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max());
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}

// vim: syntax=cpp.doxygen
