
opr_impl.cpp 33 kB

#include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/algos.h"
#include "src/fallback/convolution/run_conv.h"
#include "src/naive/convolution/helper.h"
#include "src/naive/handle.h"

#include "midout.h"

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
#include "src/arm_common/convolution/opr_impl.h"
#endif

#include <cstring>
#include <limits>
#include <unordered_map>

MIDOUT_DECL(megdnn_fb_convbwd_float)

using namespace megdnn;
using namespace fallback;

namespace {
template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) {
    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
}
}  // namespace
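
// AlgoPack collects every algorithm this operator can dispatch to: each packed
// ConvBias algorithm is wrapped in an AlgoDefault adapter (owned via refhold),
// followed by the handwritten fallback and naive implementations.
// m_all_algos_map indexes them by algorithm descriptor so that
// get_algorithm_from_desc() can resolve a descriptor back to its algorithm.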
class ConvolutionImpl::AlgoPack : NonCopyableObj {
    AlgoFallback algo_fallback;
    AlgoNaive algo_naive;
    SmallVector<std::unique_ptr<AlgoBase>> refhold;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        static CpuOprDelegationStorage<1> storage;
        auto conv_bias_opr = storage.get<ConvBias, 0>();
        auto&& conv_bias_algo =
                static_cast<ConvBiasImpl*>(conv_bias_opr)->get_all_packed_algo();
        for (auto&& algorithm : conv_bias_algo) {
            // fallback algo
            refhold.emplace_back(new AlgoDefault(algorithm));
            m_all_algos.emplace_back(refhold.back().get());
        }
        m_all_algos.emplace_back(&algo_fallback);
        m_all_algos.emplace_back(&algo_naive);

        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

const ConvolutionImpl::AlgoPack& ConvolutionImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::get_all_packed_algo() {
    return algo_pack().all_algos();
}

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    megdnn_assert(
            nr_type_contain(target_type.data_type),
            "ConvBias algo selection only support one type");
    SmallVector<ConvolutionImpl::AlgoBase*> algos;
    for (auto&& algo : get_all_packed_algo()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}

bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
}

#define NCB_ALGO_FUNC(name, algo, param) static_cast<AlgoBase*>(algo)->name(param)
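
// exec() first tries the NCB path: if the selected algorithm is not the naive
// DEFAULT one and its workspace requirement fits the caller-provided
// workspace, the dispatched kernels run; otherwise execution falls through to
// the naive reference implementation.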
void ConvolutionImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter, workspace);
    auto&& algo = get_algorithm(fparam, workspace.size);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_workspace, algo, fparam) <= workspace.size) {
        exec_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec(
                src, filter, dst, preprocessed_filter, workspace);
    }
}
void ConvolutionImpl::exec_preprocess(
        const TensorLayout& src_layout, _megdnn_tensor_in filter,
        const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter,
        _megdnn_workspace workspace) {
    //! exec_preprocess currently only supports preprocessing the weights
    //! before exec; src/dst are ignored, so they are just set to nullptr
    TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
    auto fparam = make_ncb_kern_param(src, filter, dst, preprocessed_filter, workspace);
    //! do not pass a workspace size limit here, otherwise a matching algo may
    //! not be found
    auto&& algo = get_algorithm(fparam);
    if (!is_naive_algo(algo) &&
        NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam) <= workspace.size) {
        exec_preprocess_with_ncb_kern(fparam, algo);
    } else {
        naive::ConvolutionForwardImpl::exec_preprocess(
                src_layout, filter, dst_layout, preprocessed_filter, workspace);
    }
}
size_t ConvolutionImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    TensorLayoutArray layouts{src, filter, dst};
    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
    auto rst = AlgorithmCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }

    auto fparam = make_ncb_kern_size_param(src, filter, dst, preprocessed_filter);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_workspace_in_bytes(
                src, filter, dst, preprocessed_filter);
    } else {
        return NCB_ALGO_FUNC(get_workspace, algo, fparam);
    }
}

size_t ConvolutionImpl::get_preprocess_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::get_preprocess_workspace_in_bytes(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(get_preprocess_workspace, algo, fparam);
    }
}

SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto&& algo = get_algorithm(fparam);
    if (is_naive_algo(algo)) {
        return naive::ConvolutionForwardImpl::deduce_preprocessed_filter_layout(
                src, filter, dst);
    } else {
        return NCB_ALGO_FUNC(deduce_preprocessed_filter_layout, algo, fparam);
    }
}
std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto ret = get_all_algorithms_with_ncb(fparam);
    if (ret.empty()) {
        return naive::ConvolutionForwardImpl::get_all_algorithms_safe(src, filter, dst);
    }
    return ret;
}

std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) {
    auto ret_safe = ConvolutionImpl::get_all_algorithms(src, filter, dst);
    megdnn_assert(!ret_safe.empty(), "no usable conv fwd algorithm");
    return ret_safe;
}
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
    auto result = get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
    if (result == nullptr) {
        result = naive::ConvolutionForwardImpl::get_algorithm_heuristic(
                src, filter, dst, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    return result;
}
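
// make_ncb_kern_size_param() flattens the layouts into an NCBKernSizeParam.
// spatial_pos locates H/W inside the layout: index 2 for NCHW and the blocked
// NCHW4/8/44/88 variants, index 1 for NHWC. Shapes are range-checked into
// uint32_t via safe_u32 before being packed.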
ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
        const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
        const PreprocessedFilter* preprocessed_filter) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(
                v <= std::numeric_limits<uint32_t>::max(), "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW88 ||
        param().format == Param::Format::NCHW8 ||
        param().format == Param::Format::NCHW4 ||
        param().format == Param::Format::NCHW44_DOT ||
        param().format == Param::Format::NCHW44) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else if (param().format == Param::Format::NHWC) {
        spatial_pos = 1;
    } else {
        megdnn_assert(0, "invalid conv format %d", static_cast<int>(param().format));
    }
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    return {safe_u32(src[0]),
            {{safe_u32(src[spatial_pos]), safe_u32(src[spatial_pos + 1])}},
            {{safe_u32(dst[spatial_pos]), safe_u32(dst[spatial_pos + 1])}},
            check_layout_fwd(src, filter, dst),
            src.dtype,
            filter.dtype,
            dst.dtype,
            src.stride[0],
            dst.stride[0],
            {src.stride[0], src.stride[1], src.stride[2], src.stride[3]},
            {dst.stride[0], dst.stride[1], dst.stride[2], dst.stride[3]},
            param().compute_mode,
            nr_threads,
            preprocessed_filter,
            handle()};
}

ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) = make_ncb_kern_size_param(
            src.layout, filter.layout, dst.layout, preprocessed_filter);
    ret.src_ptr = src.get_ref_ptr();
    ret.filter_ptr = filter.get_ref_ptr();
    ret.dst_ptr = dst.get_ref_ptr();
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
void ConvolutionImpl::exec_preprocess_with_ncb_kern(
        const NCBKernParam& param, Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
                "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}

void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo) {
    auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
    auto&& fallback_handle = handle();
    for (auto&& kernel : kerns) {
        megdnn_assert(
                param.filter_meta.format == Param::Format::NCHW ||
                        param.filter_meta.format == Param::Format::NHWC ||
                        param.filter_meta.format == Param::Format::NCHW88 ||
                        param.filter_meta.format == Param::Format::NCHW44 ||
                        param.filter_meta.format == Param::Format::NCHW44_DOT,
                "invalid conv format");
        auto run = [param, kernel](size_t index, size_t thread_id) {
            CpuNDRange ndrange_id(kernel.global_size, index);
            kernel.kern(param, {thread_id, ndrange_id});
        };
        static_cast<naive::HandleImpl*>(fallback_handle)
                ->dispatch_kern(run, kernel.global_size.total_size());
    }
}
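
// Heuristic selection walks the algorithm categories in the order suggested
// by the ConvBias operator. Within a category, the first algorithm marked
// preferred wins; failing that, the first usable algorithm whose workspace
// fits the limit is returned. A nullptr result means the caller should fall
// back to the naive implementation.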
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
        const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto algo_data_type = param.deduce_algo_data_type();
    auto suggest_category_order = suggest_algo_category_order(param);
    for (auto category : suggest_category_order) {
        auto&& origin_algos = select_algo_type({algo_data_type, category});
        ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
        for (auto i : origin_algos) {
            bool usable_attribute = static_cast<AlgoBase*>(i)->usable_attribute(
                    param, AlgoSelectionStrategy::HEURISTIC, positive_attr,
                    negative_attr);
            if (usable_attribute && static_cast<AlgoBase*>(i)->get_workspace(param) <=
                                            workspace_limit_in_bytes) {
                //! remember the first usable algo; it becomes the target algo
                //! if no preferred algo is found
                if (!heuristic_algo) {
                    heuristic_algo = i;
                }
                //! choose the first preferred algo
                if (i->is_preferred(param)) {
                    return i;
                }
            }
        }
        if (heuristic_algo) {
            return heuristic_algo;
        }
    }
    return nullptr;
}
std::vector<ConvolutionImpl::Algorithm*> ConvolutionImpl::get_all_algorithms_with_ncb(
        const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : get_all_packed_algo()) {
        if (i->usable(param, AlgoSelectionStrategy::FULL_RUN)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}

ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_fwd_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}
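
// get_algorithm() honors an explicitly configured execution policy first,
// then memoizes a single heuristic choice: the previously seen
// NCBKernSizeParam is kept and compared bytewise (memcmp), so repeated calls
// with an identical size param skip the heuristic search entirely.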
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
        const NCBKernSizeParam& param, size_t workspace_size) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = get_algorithm_heuristic_with_ncb(
                param, workspace_size, AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    static CpuOprDelegationStorage<1> storage;
    auto conv_bias_opr = storage.get<ConvBias, 0>();
    auto conv_bias_param = ConvolutionImpl::AlgoDefault::init_conv_bias_param(param);
    return static_cast<ConvBiasImpl*>(conv_bias_opr)
            ->suggest_algo_category_order(conv_bias_param);
}

const char* ConvolutionImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "F0";
}
ConvolutionImpl::AlgoDataType ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type()
        const {
    if (src_type.enumv() == DTypeEnum::Float32) {
        return ConvolutionImpl::AlgoDataType::FLOAT32;
#if !MEGDNN_DISABLE_FLOAT16
    } else if (src_type.enumv() == DTypeEnum::Float16) {
        return ConvolutionImpl::AlgoDataType::FLOAT16;
#endif
    } else if (
            src_type.enumv() == DTypeEnum::Int8 ||
            src_type.enumv() == DTypeEnum::QuantizedS8) {
        if (dst_type.enumv() == DTypeEnum::Int16) {
            return ConvolutionImpl::AlgoDataType::INT8X8X16;
        } else {
            return ConvolutionImpl::AlgoDataType::QINT8X8X32;
        }
    } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
    } else if (
            src_type.enumv() == DTypeEnum::QuantizedS4 ||
            src_type.enumv() == DTypeEnum::Quantized4Asymm) {
        return ConvolutionImpl::AlgoDataType::QINT4x4x32;
    } else {
        megdnn_throw(ssprintf(
                "unsupported data type: %s * %s -> %s\n", src_type.name(),
                filter_type.name(), dst_type.name()));
    }
}
/* ===================== ConvolutionBackwardData ===================== */

class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
    AlgoNaive algo_naive;
    AlgoDirect algo_direct;
    AlgoMatrixMul algo_matmul;
    SmallVector<AlgoBase*> m_all_algos;
    AlgoBase::Mapper m_all_algos_map;

public:
    AlgoPack() {
        m_all_algos.emplace_back(&algo_matmul);
        m_all_algos.emplace_back(&algo_direct);
        m_all_algos.emplace_back(&algo_naive);

        for (auto&& algo : m_all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    const SmallVector<AlgoBase*>& all_algos() const { return m_all_algos; }
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

const ConvolutionBackwardDataImpl::AlgoPack& ConvolutionBackwardDataImpl::algo_pack() {
    static AlgoPack algo_pack;
    return algo_pack;
}

SmallVector<ConvolutionBackwardDataImpl::AlgoBase*> ConvolutionBackwardDataImpl::
        get_all_packed_algo() {
    return algo_pack().all_algos();
}

void ConvolutionBackwardDataImpl::exec(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        ((param().format == param::Convolution::Format::NCHW ||
          param().format == param::Convolution::Format::NHWC) &&
         grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace);
    }
    auto fparam = make_ncb_kern_param(filter, diff, grad, workspace);
    return exec_with_ncb_kern(fparam);
}

size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    TensorLayoutArray layouts{filter, diff, grad};
    AlgorithmCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(),
                            &this->param(), sizeof(this->param())};
    auto rst = AlgorithmCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        ((param().format == param::Convolution::Format::NCHW ||
          param().format == param::Convolution::Format::NHWC) &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_workspace_with_ncb(fparam);
}
std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        ((param().format == param::Convolution::Format::NCHW ||
          param().format == param::Convolution::Format::NHWC) &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
                filter, diff, grad);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    auto ret = get_all_algorithms_with_ncb(fparam);
    return ret;
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms_safe(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad) {
    auto ret_safe = ConvolutionBackwardDataImpl::get_all_algorithms(filter, diff, grad);
    megdnn_assert(!ret_safe.empty(), "no usable conv bwd algorithm");
    return ret_safe;
}
ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_heuristic(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad, size_t workspace_limit_in_bytes,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    if (param().format == param::Convolution::Format::NHWCD4 ||
        param().format == param::Convolution::Format::NCHW4 ||
        ((param().format == param::Convolution::Format::NCHW ||
          param().format == param::Convolution::Format::NHWC) &&
         grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
        return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
                filter, diff, grad, workspace_limit_in_bytes, positive_attr,
                negative_attr);
    }
    auto fparam = make_ncb_kern_size_param(filter, diff, grad);
    return get_algorithm_heuristic_with_ncb(
            fparam, workspace_limit_in_bytes, positive_attr, negative_attr);
}

ConvolutionBackwardDataImpl::NCBKernSizeParam ConvolutionBackwardDataImpl::
        make_ncb_kern_size_param(
                const TensorLayout& filter, const TensorLayout& diff,
                const TensorLayout& grad) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(
                v <= std::numeric_limits<uint32_t>::max(), "value too large: %zu", v);
        return v;
    };
    size_t spatial_pos;
    if (param().format == Param::Format::NCHW) {
        spatial_pos = 2;
    } else {
        megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
        spatial_pos = 1;
    }
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    return {
            safe_u32(diff[0]),
            {{safe_u32(diff[spatial_pos]), safe_u32(diff[spatial_pos + 1])}},
            {{safe_u32(grad[spatial_pos]), safe_u32(grad[spatial_pos + 1])}},
            check_layout_fwd(grad_fwd, filter_fwd, diff_fwd),
            diff.dtype,
            filter.dtype,
            grad.dtype,
            diff,
            filter,
            grad,
            diff.stride[0],
            grad.stride[0],
            0,
            0,
            0,
            param().compute_mode,
    };
}

ConvolutionBackwardDataImpl::NCBKernParam ConvolutionBackwardDataImpl::
        make_ncb_kern_param(
                _megdnn_tensor_in filter, _megdnn_tensor_in diff,
                _megdnn_tensor_out grad, _megdnn_workspace workspace) {
    NCBKernParam ret;
    static_cast<NCBKernSizeParam&>(ret) =
            make_ncb_kern_size_param(filter.layout, diff.layout, grad.layout);
    auto required_workspace_in_bytes = get_workspace_with_ncb(ret);
    megdnn_assert(
            workspace.size >= required_workspace_in_bytes,
            "required workspace: %zu; provided workspace: %zu",
            required_workspace_in_bytes, workspace.size);
    ret.filter_ptr = filter.get_ref_ptr();
    ret.diff_ptr = diff.get_ref_ptr();
    ret.grad_ptr = grad.get_ref_ptr();
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}
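
// Grouped backward-data convolution is lowered to group-1 kernels: the algo
// is chosen for a single group, then the same kernel runs `group` times while
// the diff/filter/grad pointers advance by the per-group strides computed
// below (spatial extents are folded in for NCHW; NHWC channels are already
// contiguous in the innermost dimension, so the plain channel strides
// suffice).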
void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param) {
    auto p1g = param;
    auto group = p1g.filter_meta.group;
    p1g.filter_meta.group = 1;
    auto&& algo = get_algorithm(p1g);
    auto kptr = ncb_1g_dispatch_kern(algo, p1g);
    if (group == 1 || static_cast<AlgoBase*>(algo)->is_naive()) {
        auto run = [kptr, param]() { kptr(param); };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    } else {
        megdnn_assert(
                p1g.filter_meta.format == Param::Format::NCHW ||
                        p1g.filter_meta.format == Param::Format::NHWC,
                "invalid conv format");
        auto run = [kptr, p1g_orig = p1g, group]() {
            auto p1g = p1g_orig;
            ptrdiff_t istrd, fstrd, ostrd;
            fstrd = p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            istrd = p1g.filter_meta.ocpg * p1g.diff_type.size();
            ostrd = p1g.filter_meta.icpg * p1g.grad_type.size();
            p1g.diff_extra_mem_size =
                    (group - 1) * p1g.filter_meta.ocpg * p1g.diff_type.size();
            p1g.filter_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.filter_meta.ocpg *
                    p1g.filter_meta.spatial[0] * p1g.filter_meta.spatial[1] *
                    p1g.filter_type.size();
            p1g.grad_extra_mem_size =
                    (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
            if (p1g.filter_meta.format == Param::Format::NCHW) {
                istrd *= p1g.isz[0] * p1g.isz[1];
                ostrd *= p1g.osz[0] * p1g.osz[1];
                p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
                p1g.grad_extra_mem_size *= p1g.osz[0] * p1g.osz[1];
            } else {
                // must be NHWC. No action performed.
            }
            for (size_t i = 0; i < group; ++i) {
                kptr(p1g);
                p1g.diff_ptr += istrd;
                p1g.filter_ptr += fstrd;
                p1g.grad_ptr += ostrd;
                p1g.diff_extra_mem_size -= istrd;
                p1g.filter_extra_mem_size -= fstrd;
                p1g.grad_extra_mem_size -= ostrd;
            }
        };
        static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run);
    }
}
size_t ConvolutionBackwardDataImpl::get_workspace_with_ncb(
        const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        auto algo = get_algorithm(p1g);
        return ncb_1g_get_workspace(algo, p1g);
    }
    auto algo = get_algorithm(param);
    return ncb_1g_get_workspace(algo, param);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_all_algorithms(p1g);
    }
    return ncb_1g_get_all_algorithms(param);
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_heuristic_with_ncb(
                const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    if (param.filter_meta.group != 1) {
        auto p1g = param;
        p1g.filter_meta.group = 1;
        return ncb_1g_get_algorithm_heuristic(
                p1g, workspace_limit_in_bytes, positive_attr, negative_attr);
    }
    return ncb_1g_get_algorithm_heuristic(
            param, workspace_limit_in_bytes, positive_attr, negative_attr);
}

size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
        Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
    }
    return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::
        ncb_1g_dispatch_kern(Algorithm* algo, const NCBKernSizeParam& param) {
    megdnn_assert(param.filter_meta.group == 1);
    if (algo->handle_type() == Handle::HandleType::FALLBACK) {
        return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
    }
    megdnn_throw("no suitable ConvolutionBackwardData algorithm");
}
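
// The matrix-mul algorithm is preferred either when the per-group
// input/output channel product is large enough (OC * IC >= 32) or when the
// convolution degenerates into a plain GEMM: 1x1 kernel, unit stride and no
// padding.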
bool ConvolutionBackwardDataImpl::is_matrix_mul_preferred(
        const NCBKernSizeParam& param) {
    auto&& fm = param.filter_meta;
    auto OC = fm.ocpg, IC = fm.icpg;
    return (OC * IC >= 32) ||
           (fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 &&
            fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1);
}

std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl::
        ncb_1g_get_all_algorithms(const NCBKernSizeParam& param) {
    std::vector<Algorithm*> ret;
    std::vector<Algorithm*> prefer_algos;
    for (auto&& i : get_all_packed_algo()) {
        if (i->usable(this, param)) {
            if (i->is_preferred(param)) {
                prefer_algos.push_back(i);
            } else {
                ret.push_back(i);
            }
        }
    }
    ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
    return ret;
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        ncb_1g_get_algorithm_heuristic(
                const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    for (auto i : ncb_1g_get_all_algorithms(param)) {
        if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) {
            if (i->contain_attribute_all(positive_attr) &&
                !i->contain_attribute_any(negative_attr)) {
                return i;
            }
        }
    }
    megdnn_assert(0, "no suitable algorithm found within given workspace limit");
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_from_desc(const AlgorithmDesc& desc) {
    if (!desc.valid()) {
        return nullptr;
    } else {
        switch (desc.handle_type) {
            case Handle::HandleType::FALLBACK: {
                const auto& map = algo_pack().all_algos_map();
                megdnn_assert(map.find(desc) != map.end());
                return map.at(desc);
            }
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
            case Handle::HandleType::ARM_COMMON:
            case Handle::HandleType::AARCH64:
            case Handle::HandleType::ARMV7:
                return arm_common::ConvolutionBackwardDataImpl::get_algo_from_desc(
                        desc);
#endif
            case Handle::HandleType::NAIVE: {
                auto algo = static_cast<naive::HandleImpl*>(handle())
                                    ->default_conv_bwd_data_algo();
                megdnn_assert(algo->info().desc == desc);
                return algo;
            }
            default:
                megdnn_throw("Unknown handle type");
                return nullptr;
        }
    }
}

ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::get_algorithm(
        const NCBKernSizeParam& param) {
    if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
        return algo;
    }
    if (!m_prev_selected_algo ||
        memcmp(&m_prev_selected_algo_sizep, &param, sizeof(NCBKernSizeParam))) {
        m_prev_selected_algo = ncb_1g_get_algorithm_heuristic(
                param, std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT,
                AlgoAttribute::DEFAULT);
        m_prev_selected_algo_sizep = param;
    }
    return m_prev_selected_algo;
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    // fallback version 0
    return "FALLBACK_CONVOLUTION_BACKWARD_DATA_IMPL0";
}

// vim: syntax=cpp.doxygen