You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

insert_pad.cc 13 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/graph_kernel/insert_pad.h"
  17. #include <string>
  18. #include <tuple>
  19. #include <vector>
  20. #include "backend/common/session/anf_runtime_algorithm.h"
  21. #include "include/common/utils/anfalgo.h"
  22. #include "common/graph_kernel/graph_kernel_helper.h"
  23. namespace mindspore {
namespace prim {
// Primitives for the pad/unpad ops this pass inserts around MatMul.
// Defined locally in this .cc rather than in the shared primitive tables,
// since only the graph-kernel pad pass creates these node types.
inline const PrimitivePtr kPrimUnPadAkg = std::make_shared<Primitive>("UnPadAkg");
inline const PrimitivePtr kPrimPadAkg = std::make_shared<Primitive>("PadAkg");
}  // namespace prim
  28. namespace graphkernel {
  29. namespace {
// Shorthand for the shape vectors used throughout this pass.
using vec = std::vector<size_t>;
// Largest K dimension the AKG matmul kernel accepts; bigger K falls back to the
// non-AKG (cublas) path — see IsAkgMatMul below.
constexpr size_t MAX_PER_DIM_SHAPE = 4096;
// Upper bound on total work M * N * K before falling back to the cublas path.
constexpr int64_t MAX_ALL_SHAPE = static_cast<int64_t>(3e10);
  33. // M,N pad 32, K pad 16
  34. const auto GetPadShape = [](size_t K, size_t M, size_t N) {
  35. size_t pad_K = ((K - 1) / 16 + 1) * 16;
  36. size_t pad_M = ((M - 1) / 32 + 1) * 32;
  37. size_t pad_N = ((N - 1) / 32 + 1) * 32;
  38. return std::tuple(pad_K, pad_M, pad_N);
  39. };
  40. // Get (K M .. pad_N) when tran_a is true and tran_b is false
  41. const auto TransANotTransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a, vec *pad_shape_b) {
  42. size_t K, M, N, pad_K, pad_M, pad_N;
  43. size_t size = shape_a.size();
  44. K = shape_a[size - 2];
  45. M = shape_a[size - 1];
  46. N = shape_b[size - 1];
  47. std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
  48. pad_shape_a->push_back(pad_K);
  49. pad_shape_a->push_back(pad_M);
  50. pad_shape_b->push_back(pad_K);
  51. pad_shape_b->push_back(pad_N);
  52. return std::tuple(K, M, N, pad_K, pad_M, pad_N);
  53. };
  54. // Get (K M .. pad_N) when tran_a is true and tran_b is true
  55. const auto TransATransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a, vec *pad_shape_b) {
  56. size_t K, M, N, pad_K, pad_M, pad_N;
  57. size_t size = shape_a.size();
  58. K = shape_a[size - 2];
  59. M = shape_a[size - 1];
  60. N = shape_b[size - 2];
  61. std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
  62. pad_shape_a->push_back(pad_K);
  63. pad_shape_a->push_back(pad_M);
  64. pad_shape_b->push_back(pad_N);
  65. pad_shape_b->push_back(pad_K);
  66. return std::tuple(K, M, N, pad_K, pad_M, pad_N);
  67. };
  68. // Get (K M .. pad_N) when tran_a is false and tran_b is true
  69. const auto NotTransATransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a, vec *pad_shape_b) {
  70. size_t K, M, N, pad_K, pad_M, pad_N;
  71. size_t size = shape_a.size();
  72. K = shape_a[size - 1];
  73. M = shape_a[size - 2];
  74. N = shape_b[size - 2];
  75. std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
  76. pad_shape_a->push_back(pad_M);
  77. pad_shape_a->push_back(pad_K);
  78. pad_shape_b->push_back(pad_N);
  79. pad_shape_b->push_back(pad_K);
  80. return std::tuple(K, M, N, pad_K, pad_M, pad_N);
  81. };
  82. // Get (K M .. pad_N) when tran_a is false and tran_b is false
  83. const auto NotTransANotTransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a, vec *pad_shape_b) {
  84. size_t K, M, N, pad_K, pad_M, pad_N;
  85. size_t size = shape_a.size();
  86. K = shape_a[size - 1];
  87. M = shape_a[size - 2];
  88. N = shape_b[size - 1];
  89. std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
  90. pad_shape_a->push_back(pad_M);
  91. pad_shape_a->push_back(pad_K);
  92. pad_shape_b->push_back(pad_K);
  93. pad_shape_b->push_back(pad_N);
  94. return std::tuple(K, M, N, pad_K, pad_M, pad_N);
  95. };
  96. bool IsAkgMatMul(size_t K, size_t M, size_t N) {
  97. if (K > MAX_PER_DIM_SHAPE ||
  98. (static_cast<int64_t>(M) * static_cast<int64_t>(N) * static_cast<int64_t>(K)) >= MAX_ALL_SHAPE) {
  99. return false;
  100. }
  101. return true;
  102. }
// Return true flags for each of (K, M, N) that needs padding.
// On the way, fills the caller-provided shape vectors:
//   pad_shape_a / pad_shape_b — full padded device shapes of inputs A and B,
//   unpad_shape               — the unpadded output shape [batch..., M, N],
//   tail_shape_a / _b         — per-dim amount of zero padding added to A / B,
//   tail_shape_unpad          — per-dim padding to crop from the matmul output.
// Returns (false, false, false) when the transpose attrs are missing or when
// the shapes are outside the AKG kernel's range (node is then marked non-AKG).
std::tuple<bool, bool, bool> NeedPad(const CNodePtr &matmul, vec *pad_shape_a, vec *pad_shape_b, vec *unpad_shape,
                                     vec *tail_shape_a, vec *tail_shape_b, vec *tail_shape_unpad) {
  auto mm_attrs = common::AnfAlgo::GetCNodePrimitive(matmul)->attrs();
  if (mm_attrs.count("transpose_a") == 0 || mm_attrs.count("transpose_b") == 0) {
    MS_LOG(ERROR) << "attrs transpose_a and transpose_b need to be set in node " << matmul->fullname_with_scope();
    return std::tuple(false, false, false);
  }
  auto tran_a = GetValue<bool>(mm_attrs["transpose_a"]);
  auto tran_b = GetValue<bool>(mm_attrs["transpose_b"]);
  auto shape_a = AnfAlgo::GetInputDeviceShape(matmul, 0);
  auto shape_b = AnfAlgo::GetInputDeviceShape(matmul, 1);
  auto size_a = shape_a.size();
  // Batch dims pass through unchanged: copy them into every shape vector and
  // record zero padding for them.
  // NOTE(review): batch dims come from shape_a for BOTH inputs — this assumes
  // A and B have identical rank and batch dims; confirm against callers.
  for (size_t dim = 0; dim < size_a - 2; ++dim) {
    pad_shape_a->push_back(shape_a[dim]);
    pad_shape_b->push_back(shape_a[dim]);
    unpad_shape->push_back(shape_a[dim]);
    tail_shape_a->push_back(0);
    tail_shape_b->push_back(0);
    tail_shape_unpad->push_back(0);
  }
  size_t K, M, N, pad_K, pad_M, pad_N;
  using kmn = std::tuple<size_t, size_t, size_t, size_t, size_t, size_t>;
  using func = std::function<kmn(const vec &, const vec &, vec *, vec *)>;
  // Pick the layout-specific helper matching the transpose flags; it appends
  // the padded trailing two dims to pad_shape_a / pad_shape_b.
  func f = tran_a ? (tran_b ? TransATransB : TransANotTransB) : (tran_b ? NotTransATransB : NotTransANotTransB);
  std::tie(K, M, N, pad_K, pad_M, pad_N) = f(shape_a, shape_b, pad_shape_a, pad_shape_b);
  // Donot Pad for cublas operator
  if (!IsAkgMatMul(K, M, N)) {
    SetNodeAttrSafely("Akg", MakeValue(false), matmul);
    return std::tuple(false, false, false);
  }
  SetNodeAttrSafely("Akg", MakeValue(true), matmul);
  // Output keeps the true (M, N); the tail records how much the padded matmul
  // output must be cropped afterwards.
  unpad_shape->push_back(M);
  unpad_shape->push_back(N);
  tail_shape_unpad->push_back(pad_M - M);
  tail_shape_unpad->push_back(pad_N - N);
  // NOTE(review): tail_shape_b indexes shape_b with size_a — again relies on
  // A and B having equal rank.
  tail_shape_a->push_back(pad_shape_a->at(size_a - 2) - shape_a[size_a - 2]);
  tail_shape_a->push_back(pad_shape_a->at(size_a - 1) - shape_a[size_a - 1]);
  tail_shape_b->push_back(pad_shape_b->at(size_a - 2) - shape_b[size_a - 2]);
  tail_shape_b->push_back(pad_shape_b->at(size_a - 1) - shape_b[size_a - 1]);
  return std::tuple(pad_K != K, pad_M != M, pad_N != N);
}
// Insert pad for A if left is true, insert pad for B if left is false.
// Builds a PadAkg node that zero-pads the chosen matmul input up to pad_shape
// (tail_shape gives the per-dim amount appended at the end of each dim),
// copies the input's format/dtype into the new node's kernel build info, and
// rewires matmul to consume the padded value.
void InsertPad(const CNodePtr &matmul, const FuncGraphPtr &func_graph, bool left, const vec &pad_shape,
               const vec &tail_shape) {
  // CNode inputs are 1-based: input 0 is the primitive, 1 is A, 2 is B.
  size_t input_index = left ? 1 : 2;
  AnfNodePtrList pad_inp = {NewValueNode(prim::kPrimPadAkg), matmul->input(input_index)};
  auto pad_cnode = func_graph->NewCNode(pad_inp);
  func_graph->AddNode(pad_cnode);
  // "head" is all zeros (no padding before the data), "tail" is the per-dim
  // padding appended after it; pad value is integer 0.
  ShapeVector tail;
  (void)tail.insert(tail.begin(), tail_shape.begin(), tail_shape.end());
  ShapeVector head(tail_shape.size(), 0);
  SetNodeAttrSafely("head", MakeValue(head), pad_cnode);
  SetNodeAttrSafely("tail", MakeValue(tail), pad_cnode);
  SetNodeAttrSafely("pad_val", MakeValue(std::make_shared<Int32Imm>(0)), pad_cnode);
  // The pad node's abstract: padded shape with the input's inferred dtype.
  std::vector<TypeId> pad_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(matmul, 0)};
  ShapeVector abs_shape;
  (void)abs_shape.insert(abs_shape.begin(), pad_shape.begin(), pad_shape.end());
  auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(abs_shape));
  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(pad_type[0]), abs_shape_ptr);
  pad_cnode->set_abstract(abstract);
  // Kernel build info: reuse the matmul's first input format/device type for
  // the pad node's input and output.
  // NOTE(review): uses input_formats.front() even when padding input B —
  // presumably both inputs share a format; confirm.
  pad_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(matmul);
  std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(matmul);
  std::vector<std::string> pad_inp_formats = {input_formats.front()};
  std::vector<TypeId> pad_inp_types = {input_types.front()};
  std::vector<std::string> pad_output_formats = {input_formats.front()};
  std::vector<TypeId> output_types = {input_types.front()};
  auto graph_sel_info = BuildSelectKernelBuildInfo(pad_inp_formats, pad_inp_types, pad_output_formats, output_types);
  AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, pad_cnode.get());
  // Route the matmul input through the new pad node.
  matmul->set_input(input_index, pad_cnode);
}
// unpad_shape is [batch, M, N], tail_shape is [0, pad_M - M, pad_N - N].
// Builds an UnPadAkg node that crops the padded matmul output back to
// unpad_shape, copies the matmul's output format/dtype into its kernel build
// info, and replaces every user of matmul with the unpad node via the manager.
void InsertUnpad(const CNodePtr &matmul, const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
                 const vec &unpad_shape, const vec &tail_shape) {
  AnfNodePtrList unpad_inp = {NewValueNode(prim::kPrimUnPadAkg), matmul};
  auto unpad_cnode = func_graph->NewCNode(unpad_inp);
  func_graph->AddNode(unpad_cnode);
  // "tail" tells UnPadAkg how much to crop from the end of each dimension.
  ShapeVector tail;
  (void)tail.insert(tail.begin(), tail_shape.begin(), tail_shape.end());
  SetNodeAttrSafely("tail", MakeValue(tail), unpad_cnode);
  // The unpad node's abstract: original (unpadded) shape, matmul's dtype.
  std::vector<TypeId> unpad_type = {common::AnfAlgo::GetOutputInferDataType(matmul, 0)};
  ShapeVector abs_shape;
  (void)abs_shape.insert(abs_shape.begin(), unpad_shape.begin(), unpad_shape.end());
  auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(abs_shape));
  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(unpad_type[0]), abs_shape_ptr);
  unpad_cnode->set_abstract(abstract);
  // Kernel build info mirrors the matmul's output format and device type.
  unpad_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  std::vector<std::string> unpad_input_format = {AnfAlgo::GetOutputFormat(matmul, 0)};
  std::vector<TypeId> unpad_input_type = AnfAlgo::GetAllOutputDeviceTypes(matmul);
  std::vector<std::string> unpad_output_format = {unpad_input_format.front()};
  std::vector<TypeId> unpad_output_type = {unpad_input_type.front()};
  auto graph_sel_info =
    BuildSelectKernelBuildInfo(unpad_input_format, unpad_input_type, unpad_output_format, unpad_output_type);
  AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, unpad_cnode.get());
  // Redirect all consumers of the matmul to the cropped result.
  (void)mng->Replace(matmul, unpad_cnode);
}
// Update matmul's Abstract and BuildInfo as M or N is changed.
// After padding, the matmul now produces the padded shape
// unpad_shape[i] + tail_shape[i] per dimension; rebuild its abstract and its
// kernel build info accordingly (formats/dtypes are carried over unchanged).
void UpdateMatmulInfo(const AnfNodePtr &matmul_node, const vec &unpad_shape, const vec &tail_shape) {
  ShapeVector abs_shape;
  for (size_t i = 0; i < unpad_shape.size(); ++i) {
    abs_shape.push_back(unpad_shape[i] + tail_shape[i]);
  }
  auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(abs_shape));
  TypeId abs_type = common::AnfAlgo::GetOutputInferDataType(matmul_node, 0);
  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(abs_type), abs_shape_ptr);
  matmul_node->set_abstract(abstract);
  // Rebuild the select-kernel build info from the node's current formats and
  // device types so it stays consistent with the new abstract.
  std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(matmul_node);
  std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(matmul_node);
  std::vector<std::string> output_formats = AnfAlgo::GetAllOutputFormats(matmul_node);
  std::vector<TypeId> output_types = AnfAlgo::GetAllOutputDeviceTypes(matmul_node);
  auto graph_sel_info = BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types,
                                                  AnfAlgo::GetProcessor(matmul_node));
  AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, matmul_node.get());
}
// For every MatMul in func_graph, insert PadAkg nodes on the inputs whose
// dims need alignment and an UnPadAkg node on the output, so the AKG kernel
// only ever sees aligned shapes. Returns true when the graph was modified.
bool InsertPadUnpad(const FuncGraphPtr &func_graph) {
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  auto todos = TopoSort(func_graph->get_return());
  bool changed = false;
  for (const auto &n : todos) {
    if (!common::AnfAlgo::CheckPrimitiveType(n, prim::kPrimMatMul)) continue;
    auto mm_cnode = n->cast<CNodePtr>();
    vec pad_shape_a, pad_shape_b, tail_shape_a, tail_shape_b, tail_shape_unpad, unpad_shape;
    bool pad_K{false}, pad_M{false}, pad_N{false};
    std::tie(pad_K, pad_M, pad_N) =
      NeedPad(mm_cnode, &pad_shape_a, &pad_shape_b, &unpad_shape, &tail_shape_a, &tail_shape_b, &tail_shape_unpad);
    // NOTE(review): NeedPad may still set the "Akg" attr on the node before
    // returning all-false; in that case `changed` stays false — confirm that
    // attr-only changes need no graph refresh.
    if (!pad_K && !pad_M && !pad_N) continue;
    // K affects both inputs; M only A, N only B.
    if (pad_K || pad_M) {
      InsertPad(mm_cnode, func_graph, true, pad_shape_a, tail_shape_a);
    }
    if (pad_K || pad_N) {
      InsertPad(mm_cnode, func_graph, false, pad_shape_b, tail_shape_b);
    }
    // Only M/N padding changes the output shape, so only then do we need to
    // refresh the matmul's info and crop the result.
    if (pad_M || pad_N) {
      UpdateMatmulInfo(mm_cnode, unpad_shape, tail_shape_unpad);
      InsertUnpad(mm_cnode, func_graph, mng, unpad_shape, tail_shape_unpad);
    }
    changed = true;
  }
  return changed;
}
  245. } // namespace
  246. /* MatMul
  247. *
  248. * C = MatMul(A, B)
  249. * ------>
  250. * A_pad = PadAkg(A)
  251. * B_pad = PadAkg(B)
  252. * C_pad = MatMul(A_pad, B_pad)
  253. * C = UnPadAkg(C_pad)
  254. *
  255. */
  256. bool InsertPadOps::Run(const FuncGraphPtr &func_graph) {
  257. MS_EXCEPTION_IF_NULL(func_graph);
  258. auto mng = func_graph->manager();
  259. if (mng == nullptr) {
  260. mng = Manage(func_graph, true);
  261. func_graph->set_manager(mng);
  262. }
  263. auto changed = false;
  264. auto nodes = TopoSort(func_graph->get_return());
  265. for (auto node : nodes) {
  266. if (!common::AnfAlgo::IsGraphKernel(node)) continue;
  267. auto graph_kernel_fg = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  268. MS_EXCEPTION_IF_NULL(graph_kernel_fg);
  269. changed = InsertPadUnpad(graph_kernel_fg) || changed;
  270. }
  271. if (changed) {
  272. mng->RemoveRoots();
  273. mng->KeepRoots({func_graph});
  274. }
  275. return changed;
  276. }
  277. } // namespace graphkernel
  278. } // namespace mindspore