
opr_impl.cpp 32 kB

/**
 * \file dnn/src/fallback/convolution/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/algos.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/fallback/convolution/run_conv.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"

#include "midout.h"

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
#include "src/arm_common/convolution/opr_impl.h"
#endif

#include <cstring>
#include <unordered_map>

MIDOUT_DECL(megdnn_fb_convbwd_float)

using namespace megdnn;
using namespace fallback;

namespace {
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
}  // namespace
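
//! AlgoPack gathers every candidate algorithm for forward convolution: all
//! packed ConvBias algorithms wrapped as AlgoDefault, followed by the generic
//! fallback and naive kernels. A desc->algo map is built for fast lookup.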
class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->get_all_packed_algo();
        for (auto&& algorithm : conv_bias_algo) {
            // fallback algo
            refhold.emplace_back(new AlgoDefault(algorithm));
            m_all_algos.emplace_back(refhold.back().get());
        }
        m_all_algos.emplace_back(&algo_fallback);
        m_all_algos.emplace_back(&algo_naive);
        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }
    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
const ConvolutionImpl::AlgoPack& ConvolutionImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    megdnn_assert(nr_type_contain(target_type.data_type),
                  "ConvBias algo selection only support one type");
    SmallVector<ConvolutionImpl::AlgoBase*> algos;
    for (auto&& algo : get_all_packed_algo()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}

bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}
#define NCB_ALGO_FUNC(name, algo, param) \
    static_cast<AlgoBase*>(algo)->name(param)
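
//! exec() runs the selected NCB kernel when the algorithm is not naive and
//! the caller-provided workspace is large enough; otherwise it falls back to
//! the naive forward implementation.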
void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
                           _megdnn_tensor_out dst,
                           const PreprocessedFilter* preprocessed_filter,
                           _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter,
                                      workspace);
    auto&& algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec(src, filter, dst,
                                            preprocessed_filter, workspace);
    }
}
void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout,
                                      _megdnn_tensor_in filter,
                                      const TensorLayout& dst_layout,
                                      PreprocessedFilter* preprocessed_filter,
                                      _megdnn_workspace workspace) {
    //! exec_preprocess currently only supports preprocessing the weights
    //! before exec; src/dst are ignored, so their pointers are set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter,
                                      workspace);
    //! do not pass a workspace_size limit here, otherwise no matching algo
    //! can be found
    auto&& algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <=
                workspace.size) {
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec_preprocess(
                src_layout, filter, dst_layout, preprocessed_filter, workspace);
    }
}
size_t ConvolutionImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    auto fparam =
            make_ncb_kern_size_param(src, filter, dst, preprocessed_filter);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
                src, filter, dst, preprocessed_filter);
    } else {
        return NCB_ALGO_FUNC(get_workspace, algo, fparam);
    }
}

size_t ConvolutionImpl::get_preprocess_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_preprocess_workspace_in_bytes(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
    }
}

SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::deduce_preprocessed_filter_layout(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, algo, fparam);
    }
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
        return naive::ConvolutionForwardImpl::get_all_algorithms(src, filter,
                                                                 dst);
    }
    return ret;
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
    if (result == nullptr) {
        result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
                src, filter, dst, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    return result;
}
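
//! The spatial dimensions (H, W) start at index 2 for NCHW-like formats and
//! at index 1 for NHWC.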
ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter,
        const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW44_DOT ||
        param().format == Param::Format::NCHW44) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NHWC) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d",
                      static_cast<int>(param().format));
    }
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    return {safe_u32(src[0]),
            {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
            {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
            check_layout_fwd(src, filter, dst),
            src.dtype,
            filter.dtype,
            dst.dtype,
            src.stride[0],
            dst.stride[0],
            {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
            {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
            param().compute_mode,
            nr_threads,
            preprocessed_filter};
}
ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) = make_ncb_kern_size_param(
            src.layout, filter.layout, dst.layout, preprocessed_filter);
    ret.src_ptr = src.raw_ptr;
    ret.filter_ptr = filter.raw_ptr;
    ret.dst_ptr = dst.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
                                                    Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(param.filter_meta.format == Param::Format::NCHW ||
                              param.filter_meta.format == Param::Format::NHWC ||
                              param.filter_meta.format ==
                                      Param::Format::NCHW88 ||
                              param.filter_meta.format ==
                                      Param::Format::NCHW44 ||
                              param.filter_meta.format ==
                                      Param::Format::NCHW44_DOT,
                      "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                         Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(param.filter_meta.format == Param::Format::NCHW ||
                              param.filter_meta.format == Param::Format::NHWC ||
                              param.filter_meta.format ==
                                      Param::Format::NCHW88 ||
                              param.filter_meta.format ==
                                      Param::Format::NCHW44 ||
                              param.filter_meta.format ==
                                      Param::Format::NCHW44_DOT,
                      "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}
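
//! Heuristic selection walks the algo categories in the suggested order;
//! within a category, the first preferred usable algorithm wins, otherwise
//! the first usable one that fits the workspace limit is chosen.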
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto algo_data_type = param.deduce_algo_data_type();
    auto suggest_category_order = suggest_algo_category_order(param);
    for (auto category : suggest_category_order) {
        auto&& origin_algos = select_algo_type({algo_data_type, category});
        ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
        for (auto i : origin_algos) {
            bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
                    param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
                    negative_attr);
            if (usable_attribute &&
                static_cast<AlgoBase*>(i)->get_workspace(param) <=
                        workspace_limit_in_bytes) {
                //! remember the first usable algo; it becomes the target if
                //! no preferred algo is found
                if (!heuristic_algo) {
                    heuristic_algo = i;
                }
                //! choose the first preferred algo
                if (i->is_preferred(param)) {
                    return i;
                }
            }
        }
        if (heuristic_algo) {
            return heuristic_algo;
        }
    }
    return nullptr;
}
std::vector<ConvolutionImpl::Algorithm*>
ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : get_all_packed_algo()) {
        if (i->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_fwd_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}
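
//! The previously selected algorithm is cached and only recomputed when the
//! size parameters change (compared bytewise via memcmp).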
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
                param, workspace_size, AlgoAttribute::DEFAULT,
                AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    static CpuOprDelegationStorage<1> storage;
    auto conv_bias_opr = storage.get<ConvBias, 0>();
    auto conv_bias_param =
            ConvolutionImpl::AlgoDefault::init_conv_bias_param(param);
    return static_cast<ConvBiasImpl*>(conv_bias_opr)
            ->suggest_algo_category_order(conv_bias_param);
}

const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}
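
//! Map the src/dst dtype combination to the AlgoDataType used when filtering
//! candidate algorithms.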
ConvolutionImpl::AlgoDataType
ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
    if (src_type.enumv() == DTypeEnum::Float32) {
        return ConvolutionImpl::AlgoDataType::FLOAT32;
#if !MEGDNN_DISABLE_FLOAT16
    } else if (src_type.enumv() == DTypeEnum::Float16) {
        return ConvolutionImpl::AlgoDataType::FLOAT16;
#endif
    } else if (src_type.enumv() == DTypeEnum::Int8 ||
               src_type.enumv() == DTypeEnum::QuantizedS8) {
        if (dst_type.enumv() == DTypeEnum::Int16) {
            return ConvolutionImpl::AlgoDataType::INT8X8X16;
        } else {
            return ConvolutionImpl::AlgoDataType::QINT8X8X32;
        }
    } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
    } else if (src_type.enumv() == DTypeEnum::QuantizedS4) {
        return ConvolutionImpl::AlgoDataType::QINT4x4x32;
    } else {
        megdnn_throw(ssprintf("unsupported data type: %s * %s -> %s\n",
                              src_type.name(), filter_type.name(),
                              dst_type.name()));
    }
}
/* ===================== ConvolutionBackwardData ===================== */
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
    AlgoNaive algo_naive;
    AlgoDirect algo_direct;
    AlgoMatrixMul algo_matmul;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        m_all_algos.emplace_back(&algo_matmul);
        m_all_algos.emplace_back(&algo_direct);
        m_all_algos.emplace_back(&algo_naive);
        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }
    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

const ConvolutionBackwardDataImpl::AlgoPack&
ConvolutionBackwardDataImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvolutionBackwardDataImpl::AlgoBase*>
ConvolutionBackwardDataImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}
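
//! Layouts the NCB path cannot handle (NHWCD4, NCHW4) and quantized int8
//! gradients in NCHW are delegated to the naive backward-data implementation.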
void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter,
                                       _megdnn_tensor_in diff,
                                       _megdnn_tensor_out grad,
                                       _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad,
                                                        workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}

size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_workspace_with_ncb(fparam);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter,
                                                const TensorLayout& diff,
                                                const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    auto ret = get_all_algorithms_with_ncb(fparam);
    megdnn_assert(!ret.empty(), "no usable conv bwd_data algorithm");
    return ret;
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        (param().format == param::Convolution::Format::NCHW &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
                filter, diff, grad, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_algorithm_heuristic_with_ncb(fparam, workspace_limit_in_bytes,
                                            positive_attr, negative_attr);
}
ConvolutionBackwardDataImpl::NCBKernSizeParam
ConvolutionBackwardDataImpl::make_ncb_kern_size_param(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC,
                      "invalid conv format");
        spatial_pos = 1;
    }
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            0,
            0,
            0,
            param().compute_mode,
    };
}
ConvolutionBackwardDataImpl::NCBKernParam
ConvolutionBackwardDataImpl::make_ncb_kern_param(_megdnn_tensor_in filter,
                                                 _megdnn_tensor_in diff,
                                                 _megdnn_tensor_out grad,
                                                 _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);
    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(workspace.size >= required_workspace_in_bytes,
                  "required workspace: %zu; provided workspace: %zu",
                  required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.raw_ptr;
    ret.diff_ptr = diff.raw_ptr;
    ret.grad_ptr = grad.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
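
//! Grouped convolutions are executed as `group` independent single-group
//! kernels: the parameter is rewritten with group == 1 and the
//! diff/filter/grad pointers are advanced by per-group strides after each
//! kernel call.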
void ConvolutionBackwardDataImpl::exec_with_ncb_kern(
        const NCBKernParam& param) {
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
    auto&& algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (group == 1 || static_cast<AlgoBase*>(algo)->is_naive()) {
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
        megdnn_assert(p1g.filter_meta.format == Param::Format::NCHW ||
                              p1g.filter_meta.format == Param::Format::NHWC,
                      "invalid conv format");
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
            if (p1g.filter_meta.format == Param::Format::NCHW) {
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
                incr_ptr(p1g.diff_ptr, istrd);
                incr_ptr(p1g.filter_ptr, fstrd);
                incr_ptr(p1g.grad_ptr, ostrd);
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}
size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        auto algo = get_algorithm(p1g);
        return ncb_1g_get_workspace(algo, p1g);
    }
    auto algo = get_algorithm(param);
    return ncb_1g_get_workspace(algo, param);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_all_algorithms(p1g);
    }
    return ncb_1g_get_all_algorithms(param);
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_algorithm_heuristic(p1g, workspace_limit_in_bytes,
                                              positive_attr, negative_attr);
    }
    return ncb_1g_get_algorithm_heuristic(param, workspace_limit_in_bytes,
                                          positive_attr, negative_attr);
}
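
//! The ncb_1g_* helpers below always operate on the single-group view of the
//! parameters (filter_meta.group == 1), as the asserts enforce.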
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }
    megdnn_throw("no suitable ConvolutionBackwardData algorithm");
}
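
//! MatrixMul is preferred either when the per-group channel product is large
//! (OC * IC >= 32) or for 1x1, stride-1, unpadded convolutions.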
bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
        const NCBKernSizeParam& param) {
    auto&& fm = param.filter_meta;
    auto OC = fm.ocpg, IC = fm.icpg;
    return (OC * IC >= 32) ||
           (fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 &&
            fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
        const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : get_all_packed_algo()) {
        if (i->usable(this, param)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    for (auto i : ncb_1g_get_all_algorithms(param)) {
        if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
            if (i->contain_attribute_all(positive_attr) &&
                !i->contain_attribute_any(negative_attr)) {
                return i;
            }
        }
    }
    megdnn_assert(0,
                  "no suitable algorithm found within given workspace limit");
}
ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case Handle::HandleType::ARM_COMMON:
            case Handle::HandleType::AARCH64:
            case Handle::HandleType::ARMV7:
                return arm_common::ConvolutionBackwardDataImpl::
                        get_algo_from_desc(desc);
#endif
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_bwd_data_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}

ConvolutionBackwardDataImpl::Algorithm*
ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max(),
                AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}

// vim: syntax=cpp.doxygen

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has a GPU and its driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.