You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

opr_impl.cpp 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. /**
  2. * \file dnn/src/fallback/convolution/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/common/algo_chooser.h"
  13. #include "src/common/metahelper.h"
  14. #include "src/common/opr_delegate.h"
  15. #include "src/common/utils.h"
  16. #include "src/fallback/convolution/algos.h"
  17. #include "src/fallback/convolution/opr_impl.h"
  18. #include "src/fallback/convolution/run_conv.h"
  19. #include "src/naive/convolution/helper.h"
  20. #include "src/naive/handle.h"
  21. #include "midout.h"
  22. #if MEGDNN_AARCH64 || MEGDNN_ARMV7
  23. #include "src/arm_common/convolution/opr_impl.h"
  24. #endif
  25. #include <cstring>
  26. #include <unordered_map>
  27. MIDOUT_DECL(megdnn_fb_convbwd_float)
  28. using namespace megdnn;
  29. using namespace fallback;
  30. namespace {
  31. template <typename T>
  32. void incr_ptr(T*& dst, ptrdiff_t delta) {
  33. dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
  34. }
  35. } // namespace
//! Container for every forward-convolution algorithm available in the
//! fallback backend: wrapped conv-bias algos plus the fallback/naive ones.
class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    //! owns the AlgoDefault wrappers built from conv-bias algorithms
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->get_all_packed_algo();
        //! wrap each packed conv-bias algo so it can serve plain convolution
        for (auto&& algorithm : conv_bias_algo) {
            // fallback algo
            refhold.emplace_back(new AlgoDefault(algorithm));
            m_all_algos.emplace_back(refhold.back().get());
        }
        m_all_algos.emplace_back(&algo_fallback);
        m_all_algos.emplace_back(&algo_naive);
        //! index by descriptor for get_algorithm_from_desc() lookups
        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }
    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
//! Lazily-constructed singleton holding all fallback forward-conv algorithms.
const ConvolutionImpl::AlgoPack& ConvolutionImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

//! All packed algorithms, in the priority order AlgoPack established.
SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}
  69. SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
  70. ConvAlgoTypePack target_type) {
  71. megdnn_assert(nr_type_contain(target_type.data_type),
  72. "ConvBias algo selection only support one type");
  73. SmallVector<ConvolutionImpl::AlgoBase*> algos;
  74. for (auto&& algo : get_all_packed_algo()) {
  75. auto algo_type = algo->get_algo_type();
  76. if (contain_data_type(algo_type.data_type, target_type.data_type) &&
  77. algo_type.algo_category == target_type.algo_category) {
  78. algos.push_back(algo);
  79. }
  80. }
  81. return algos;
  82. }
//! An absent algorithm, or one named "DEFAULT", means "delegate to naive".
bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}

//! Invoke member \p name on \p algo (downcast to AlgoBase) with \p param.
#define NCB_ALGO_FUNC(name, algo, param) \
    static_cast<AlgoBase*>(algo)->name(param)
//! Forward-convolution entry point: run the chosen ncb algorithm if it fits
//! in the provided workspace, otherwise fall back to the naive impl.
void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                           _megdnn_tensor_out dst,
                           const PreprocessedFilter* preprocessed_filter,
                           _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter,
                                      workspace);
    auto&& algo = get_algorithm(fparam, workspace.size);
    //! workspace re-checked here because the cached algo may have been
    //! selected for a different workspace limit
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec(src, filter, dst,
                                            preprocessed_filter, workspace);
    }
}
//! Preprocess (pack/transform) the filter ahead of exec().
void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout,
                                      _megdnn_tensor_in filter,
                                      const TensorLayout& dst_layout,
                                      PreprocessedFilter* preprocessed_filter,
                                      _megdnn_workspace workspace) {
    //! exec_preprocess currently only support preprocess weights before exec,
    //! src/dst will be ignored, just set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter,
                                      workspace);
    //! should not pass workspace_size limit otherwise can not find match algo
    auto&& algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <=
                workspace.size) {
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec_preprocess(
                src_layout, filter, dst_layout, preprocessed_filter, workspace);
    }
}
  124. size_t ConvolutionImpl::get_workspace_in_bytes(
  125. const TensorLayout& src, const TensorLayout& filter,
  126. const TensorLayout& dst,
  127. const PreprocessedFilter* preprocessed_filter) {
  128. auto fparam =
  129. make_ncb_kern_size_param(src, filter, dst, preprocessed_filter);
  130. auto&& algo = get_algorithm(fparam);
  131. if (is_naive_algo(algo)) {
  132. return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
  133. src, filter, dst, preprocessed_filter);
  134. } else {
  135. return NCB_ALGO_FUNC(get_workspace, algo, fparam);
  136. }
  137. }
  138. size_t ConvolutionImpl::get_preprocess_workspace_in_bytes(
  139. const TensorLayout& src, const TensorLayout& filter,
  140. const TensorLayout& dst) {
  141. auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
  142. auto&& algo = get_algorithm(fparam);
  143. if (is_naive_algo(algo)) {
  144. return naive::ConvolutionForwardImpl::get_preprocess_workspace_in_bytes(
  145. src, filter, dst);
  146. } else {
  147. return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
  148. }
  149. }
//! Layouts of the preprocessed-filter tensors for the selected algorithm.
SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::deduce_preprocessed_filter_layout(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, algo, fparam);
    }
}
  162. std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
  163. const TensorLayout& src, const TensorLayout& filter,
  164. const TensorLayout& dst) {
  165. auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
  166. auto ret = get_all_algorithms_with_ncb(fparam);
  167. if (ret.empty()) {
  168. return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter,
  169. dst);
  170. }
  171. return ret;
  172. }
//! Heuristically pick an algorithm; falls back to naive's heuristic when no
//! fallback algorithm satisfies the workspace/attribute constraints.
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
    if (result == nullptr) {
        result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
                src, filter, dst, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    return result;
}
//! Build the size-only kernel parameter (no data pointers) describing the
//! forward problem; used by algo selection and workspace queries.
ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    //! narrow size_t -> uint32_t with an overflow assertion
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    //! index of the first spatial dim (H) in the layout for this format
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW44_DOT ||
        param().format == Param::Format::NCHW44) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NHWC) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d",
                      static_cast<int>(param().format));
    }
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    //! NOTE(review): the initializer order below presumably matches
    //! NCBKernSizeParam's member declaration order — confirm against
    //! opr_impl.h before reordering anything here
    return {safe_u32(src[0]),
            {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
            {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
            check_layout_fwd(src, filter, dst),
            src.dtype,
            filter.dtype,
            dst.dtype,
            src.stride[0],
            dst.stride[0],
            {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
            {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
            param().compute_mode,
            nr_threads,
            preprocessed_filter};
}
//! Build the full kernel parameter: size info plus raw data/workspace pointers.
ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
    NCBKernParam ret;
    //! fill the size-param base slice first, then attach the pointers
    static_cast<NCBKernSizeParam&>(ret) = make_ncb_kern_size_param(
            src.layout, filter.layout, dst.layout, preprocessed_filter);
    ret.src_ptr = src.raw_ptr;
    ret.filter_ptr = filter.raw_ptr;
    ret.dst_ptr = dst.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
//! Dispatch the selected algo's preprocess kernels on the handle's dispatcher.
void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
                                                    Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
                "invalid conv format");
        //! capture by value so param/kernel stay valid across dispatch
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

//! Dispatch the selected algo's compute kernels on the handle's dispatcher.
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                         Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
                "invalid conv format");
        //! capture by value so param/kernel stay valid across dispatch
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}
//! Heuristic over fallback algos only: walk the suggested category order and,
//! within a category, return the first preferred usable algorithm — or the
//! first usable one if nothing is preferred. Returns nullptr when no
//! category yields a candidate (caller then falls back to naive).
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    auto algo_data_type = param.deduce_algo_data_type();
    auto suggest_category_order = suggest_algo_category_order(param);
    for (auto category : suggest_category_order) {
        auto&& origin_algos = select_algo_type({algo_data_type, category});
        ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
        for (auto i : origin_algos) {
            bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
                    param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
                    negative_attr);
            if (usable_attribute &&
                static_cast<AlgoBase*>(i)->get_workspace(param) <=
                        workspace_limit_in_bytes) {
                //! store the first usable algo if no prefer algo, choose it as
                //! the target algo
                if (!heuristic_algo) {
                    heuristic_algo = i;
                }
                //! choose the first prefer algo
                if (i->is_preferred(param)) {
                    return i;
                }
            }
        }
        if (heuristic_algo) {
            return heuristic_algo;
        }
    }
    return nullptr;
}
  317. std::vector<ConvolutionImpl::Algorithm*>
  318. ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
  319. std::vector<Algorithm*> ret;
  320. std::vector<Algorithm*> prefer_algos;
  321. for (auto&& i : get_all_packed_algo()) {
  322. if (i->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
  323. if (i->is_preferred(param)) {
  324. prefer_algos.push_back(i);
  325. } else {
  326. ret.push_back(i);
  327. }
  328. }
  329. }
  330. ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
  331. return ret;
  332. }
//! Map an algorithm descriptor back to the live algorithm object; returns
//! nullptr for an unset (invalid) descriptor.
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                //! fallback algos are indexed in the AlgoPack map
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_fwd_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}
//! Resolve the algorithm for \p param: an explicitly set execution policy
//! wins; otherwise the heuristic result is cached, keyed on the size param.
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    //! NOTE(review): memcmp over the whole struct also compares padding
    //! bytes — assumes the cached copy is written the same way; confirm
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
                param, workspace_size, AlgoAttribute::DEFAULT,
                AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}
//! Delegate category ordering to the conv-bias implementation, which owns
//! the shared selection heuristics.
SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    static CpuOprDelegationStorage<1> storage;
    auto conv_bias_opr = storage.get<ConvBias, 0>();
    auto conv_bias_param =
            ConvolutionImpl::AlgoDefault::init_conv_bias_param(param);
    return static_cast<ConvBiasImpl*>(conv_bias_opr)
            ->suggest_algo_category_order(conv_bias_param);
}

//! Identifier for this implementation's algorithm-serialization namespace.
const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}
//! Classify the problem's dtypes into the AlgoDataType bucket used by
//! select_algo_type(); throws for unsupported combinations.
ConvolutionImpl::AlgoDataType
ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
    if (src_type.enumv() == DTypeEnum::Float32) {
        return ConvolutionImpl::AlgoDataType::FLOAT32;
#if !MEGDNN_DISABLE_FLOAT16
    } else if (src_type.enumv() == DTypeEnum::Float16) {
        return ConvolutionImpl::AlgoDataType::FLOAT16;
#endif
    } else if (src_type.enumv() == DTypeEnum::Int8 ||
               src_type.enumv() == DTypeEnum::QuantizedS8) {
        //! int8 source: dst dtype distinguishes 8x8x16 from 8x8x32 kernels
        if (dst_type.enumv() == DTypeEnum::Int16) {
            return ConvolutionImpl::AlgoDataType::INT8X8X16;
        } else {
            return ConvolutionImpl::AlgoDataType::QINT8X8X32;
        }
    } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
    } else {
        megdnn_throw(ssprintf("not support data type of %s * %s -> %s\n",
                              src_type.name(), filter_type.name(),
                              dst_type.name()));
    }
}
  406. /* ===================== ConvolutionBackwardData ===================== */
//! Container for all backward-data algorithms in the fallback backend.
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
    AlgoNaive algo_naive;
    AlgoDirect algo_direct;
    AlgoMatrixMul algo_matmul;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        //! registration order doubles as priority: matmul, direct, naive
        m_all_algos.emplace_back(&algo_matmul);
        m_all_algos.emplace_back(&algo_direct);
        m_all_algos.emplace_back(&algo_naive);
        //! index by descriptor for get_algorithm_from_desc() lookups
        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }
    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
//! Lazily-constructed singleton of backward-data algorithms.
const ConvolutionBackwardDataImpl::AlgoPack&
ConvolutionBackwardDataImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

//! All packed backward-data algorithms, in AlgoPack priority order.
SmallVector<ConvolutionBackwardDataImpl::AlgoBase*>
ConvolutionBackwardDataImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}
//! Backward-data entry point: NHWCD4/NCHW4 are not handled by the fallback
//! kernels, so those formats go straight to the naive implementation.
void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
                                       _megdnn_tensor_in diff,
                                       _megdnn_tensor_out grad,
                                       _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad,
                                                        workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}
  446. size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
  447. const TensorLayout& filter, const TensorLayout& diff,
  448. const TensorLayout& grad) {
  449. if (param().format == param::Convolution::Format::NHWCD4 ||
  450. param().format == param::Convolution::Format::NCHW4) {
  451. return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
  452. filter, diff, grad);
  453. }
  454. auto fparam = make_ncb_kern_size_param(filter, diff, grad);
  455. return get_workspace_with_ncb(fparam);
  456. }
  457. std::vector<ConvolutionBackwardDataImpl::Algorithm*>
  458. ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
  459. const TensorLayout& diff,
  460. const TensorLayout& grad) {
  461. if (param().format == param::Convolution::Format::NHWCD4 ||
  462. param().format == param::Convolution::Format::NCHW4) {
  463. return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
  464. filter, diff, grad);
  465. }
  466. auto fparam = make_ncb_kern_size_param(filter, diff, grad);
  467. auto ret = get_all_algorithms_with_ncb(fparam);
  468. megdnn_assert(!ret.empty(), "no usable conv fwd algorithm");
  469. return ret;
  470. }
//! Heuristic backward-data algorithm choice; unsupported formats delegate
//! to the naive heuristic.
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4) {
        return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
                filter, diff, grad, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
                                            positive_attr, negative_attr);
}
//! Build the size-only backward-data parameter. Layout validation reuses the
//! forward checker with grad/diff playing the src/dst roles; their dtypes are
//! swapped first so the dtype pairing matches the forward convention.
ConvolutionBackwardDataImpl::NCBKernSizeParam
ConvolutionBackwardDataImpl::make_ncb_kern_size_param(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    //! narrow size_t -> uint32_t with an overflow assertion
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    //! index of the first spatial dim (H) for the supported formats
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC,
                      "invalid conv format");
        spatial_pos = 1;
    }
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    //! NOTE(review): initializer order presumably matches NCBKernSizeParam's
    //! declaration — confirm against opr_impl.h before reordering
    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            0,
            0,
            0,
            param().compute_mode,
    };
}
//! Full backward-data kernel parameter: size info, an up-front workspace
//! size check, and the raw data pointers.
ConvolutionBackwardDataImpl::NCBKernParam
ConvolutionBackwardDataImpl::make_ncb_kern_param(_megdnn_tensor_in filter,
                                                 _megdnn_tensor_in diff,
                                                 _megdnn_tensor_out grad,
                                                 _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);
    //! reject undersized workspaces here instead of inside the kernel
    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(workspace.size >= required_workspace_in_bytes,
                  "required workspace: %zu; provided workspace: %zu",
                  required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.raw_ptr;
    ret.diff_ptr = diff.raw_ptr;
    ret.grad_ptr = grad.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
//! Execute backward-data. Multi-group problems run the single-group kernel
//! `group` times, shifting the data pointers by per-group byte strides.
void ConvolutionBackwardDataImpl::exec_with_ncb_kern(
        const NCBKernParam& param) {
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    //! algorithm selection and dispatch see a 1-group problem
    p1g.filter_meta.group = 1;
    auto&& algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (group == 1 || static_cast<AlgoBase*>(algo)->is_naive()) {
        //! naive algo handles grouping itself, so pass the original param
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
        megdnn_assert(p1g.filter_meta.format == Param::Format::NCHW ||
                              p1g.filter_meta.format == Param::Format::NHWC,
                      "invalid conv format");
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            //! per-group byte advances for diff/filter/grad pointers
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            //! *_extra_mem_size starts at the bytes owned by the remaining
            //! group-1 slices and shrinks as the pointers advance below
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
            if (p1g.filter_meta.format == Param::Format::NCHW) {
                //! NCHW: per-group data stride also spans the spatial extent
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
                incr_ptr(p1g.diff_ptr, istrd);
                incr_ptr(p1g.filter_ptr, fstrd);
                incr_ptr(p1g.grad_ptr, ostrd);
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}
//! Workspace for \p param; a grouped problem is reduced to the equivalent
//! single-group problem before querying the algorithm.
size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        auto algo = get_algorithm(p1g);
        return ncb_1g_get_workspace(algo, p1g);
    }
    auto algo = get_algorithm(param);
    return ncb_1g_get_workspace(algo, param);
}

//! All algorithms, after reducing a grouped problem to one group.
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_all_algorithms(p1g);
    }
    return ncb_1g_get_all_algorithms(param);
}

//! Heuristic choice, after reducing a grouped problem to one group.
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
                                              positive_attr, negative_attr);
    }
    return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
                                          positive_attr, negative_attr);
}
//! Workspace for a single-group problem; only fallback-handle algorithms
//! report a non-zero requirement here.
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    return 0;
}

//! Kernel functor for a single-group problem; throws for non-fallback algos.
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }
    megdnn_throw("no suitable ConvolutionBackwardData algorithm");
}
  649. bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
  650. const NCBKernSizeParam& param) {
  651. auto&& fm = param.filter_meta;
  652. auto OC = fm.ocpg, IC = fm.icpg;
  653. return (OC * IC >= 32) ||
  654. (fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 &&
  655. fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1);
  656. }
  657. std::vector<ConvolutionBackwardDataImpl::Algorithm*>
  658. ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
  659. const NCBKernSizeParam& param) {
  660. std::vector<Algorithm*> ret;
  661. std::vector<Algorithm*> prefer_algos;
  662. for (auto&& i : get_all_packed_algo()) {
  663. if (i->usable(this, param)) {
  664. if (i->is_preferred(param)) {
  665. prefer_algos.push_back(i);
  666. } else {
  667. ret.push_back(i);
  668. }
  669. }
  670. }
  671. ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
  672. return ret;
  673. }
//! First algorithm (preferred first, per ncb_1g_get_all_algorithms) that fits
//! the workspace limit and the attribute constraints; asserts otherwise.
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    for (auto i : ncb_1g_get_all_algorithms(param)) {
        if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
            if (i->contain_attribute_all(positive_attr) &&
                !i->contain_attribute_any(negative_attr)) {
                return i;
            }
        }
    }
    //! NOTE(review): control would fall off the end if megdnn_assert were
    //! compiled out — presumably it always aborts on failure here; confirm
    megdnn_assert(0,
                  "no suitable algorithm found within given workspace limit");
}
//! Map a descriptor back to a live backward-data algorithm; on ARM builds
//! this may resolve into the arm_common/aarch64/armv7 backends.
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case Handle::HandleType::ARM_COMMON:
            case Handle::HandleType::AARCH64:
            case Handle::HandleType::ARMV7:
                return arm_common::ConvolutionBackwardDataImpl::
                        get_algo_from_desc(desc);
#endif
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_bwd_data_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}
//! Resolve the algorithm: an explicitly set execution policy wins; otherwise
//! the heuristic result is cached, keyed on the size param.
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    //! NOTE(review): memcmp also compares padding bytes — assumes the cached
    //! copy is written identically; confirm
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max(),
                AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

//! Identifier for this implementation's algorithm-serialization namespace.
const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}
  739. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台