You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

insert_pad.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "common/graph_kernel/insert_pad.h"

#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "common/graph_kernel/graph_kernel_helper.h"
  23. namespace mindspore {
namespace prim {
// Primitives for the auxiliary operators this pass inserts around MatMul:
// PadAkg grows a tensor to the padded shape, UnPadAkg strips the padding again.
inline const PrimitivePtr kPrimUnPadAkg = std::make_shared<Primitive>("UnPadAkg");
inline const PrimitivePtr kPrimPadAkg = std::make_shared<Primitive>("PadAkg");
}  // namespace prim
  28. namespace graphkernel {
  29. namespace {
  30. using vec = std::vector<size_t>;
  31. // M,N pad 32, K pad 16
  32. auto GetPadShape = [](size_t K, size_t M, size_t N) {
  33. size_t pad_K = ((K - 1) / 16 + 1) * 16;
  34. size_t pad_M = ((M - 1) / 32 + 1) * 32;
  35. size_t pad_N = ((N - 1) / 32 + 1) * 32;
  36. return std::tuple(pad_K, pad_M, pad_N);
  37. };
  38. // Get (K M .. pad_N) when tran_a is true and tran_b is false
  39. auto TransANotTransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a, vec *pad_shape_b) {
  40. size_t K, M, N, pad_K, pad_M, pad_N;
  41. size_t size = shape_a.size();
  42. K = shape_a[size - 2];
  43. M = shape_a[size - 1];
  44. N = shape_b[size - 1];
  45. std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
  46. pad_shape_a->push_back(pad_K);
  47. pad_shape_a->push_back(pad_M);
  48. pad_shape_b->push_back(pad_K);
  49. pad_shape_b->push_back(pad_N);
  50. return std::tuple(K, M, N, pad_K, pad_M, pad_N);
  51. };
  52. // Get (K M .. pad_N) when tran_a is true and tran_b is true
  53. auto TransATransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a, vec *pad_shape_b) {
  54. size_t K, M, N, pad_K, pad_M, pad_N;
  55. size_t size = shape_a.size();
  56. K = shape_a[size - 2];
  57. M = shape_a[size - 1];
  58. N = shape_b[size - 2];
  59. std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
  60. pad_shape_a->push_back(pad_K);
  61. pad_shape_a->push_back(pad_M);
  62. pad_shape_b->push_back(pad_N);
  63. pad_shape_b->push_back(pad_K);
  64. return std::tuple(K, M, N, pad_K, pad_M, pad_N);
  65. };
  66. // Get (K M .. pad_N) when tran_a is false and tran_b is true
  67. auto NotTransATransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a, vec *pad_shape_b) {
  68. size_t K, M, N, pad_K, pad_M, pad_N;
  69. size_t size = shape_a.size();
  70. K = shape_a[size - 1];
  71. M = shape_a[size - 2];
  72. N = shape_b[size - 2];
  73. std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
  74. pad_shape_a->push_back(pad_M);
  75. pad_shape_a->push_back(pad_K);
  76. pad_shape_b->push_back(pad_N);
  77. pad_shape_b->push_back(pad_K);
  78. return std::tuple(K, M, N, pad_K, pad_M, pad_N);
  79. };
  80. // Get (K M .. pad_N) when tran_a is false and tran_b is false
  81. auto NotTransANotTransB = [](const vec &shape_a, const vec &shape_b, vec *pad_shape_a, vec *pad_shape_b) {
  82. size_t K, M, N, pad_K, pad_M, pad_N;
  83. size_t size = shape_a.size();
  84. K = shape_a[size - 1];
  85. M = shape_a[size - 2];
  86. N = shape_b[size - 1];
  87. std::tie(pad_K, pad_M, pad_N) = GetPadShape(K, M, N);
  88. pad_shape_a->push_back(pad_M);
  89. pad_shape_a->push_back(pad_K);
  90. pad_shape_b->push_back(pad_K);
  91. pad_shape_b->push_back(pad_N);
  92. return std::tuple(K, M, N, pad_K, pad_M, pad_N);
  93. };
  94. bool IsAkgMatMul(size_t K, size_t M, size_t N) {
  95. if (K > 4096 || M * N * K >= 3e10) {
  96. return false;
  97. }
  98. return true;
  99. }
// Returns (K_changed, M_changed, N_changed): whether each of K, M, N must grow to
// reach the AKG-friendly sizes (M, N -> multiple of 32, K -> multiple of 16).
// Side effects on success:
//  - *pad_shape_a / *pad_shape_b receive the full padded device shapes of inputs A and B;
//  - *unpad_shape receives the original output shape [batch..., M, N];
//  - *tail_shape_a / *tail_shape_b / *tail_shape_unpad receive per-dimension padding
//    amounts (padded minus original, 0 for batch dims);
//  - the "Akg" attr is set on `matmul` (false routes the node to cublas, in which
//    case the function bails out early with all-false flags and the tail/unpad
//    vectors left without their trailing two entries).
std::tuple<bool, bool, bool> NeedPad(const CNodePtr &matmul, vec *pad_shape_a, vec *pad_shape_b, vec *unpad_shape,
                                     vec *tail_shape_a, vec *tail_shape_b, vec *tail_shape_unpad) {
  auto mm_attrs = common::AnfAlgo::GetCNodePrimitive(matmul)->attrs();
  if (mm_attrs.count("transpose_a") == 0 || mm_attrs.count("transpose_b") == 0) {
    MS_LOG(ERROR) << "attrs transpose_a and transpose_b need to be set in node " << matmul->fullname_with_scope();
    return std::tuple(false, false, false);
  }
  auto tran_a = GetValue<bool>(mm_attrs["transpose_a"]);
  auto tran_b = GetValue<bool>(mm_attrs["transpose_b"]);
  auto shape_a = AnfAlgo::GetInputDeviceShape(matmul, 0);
  auto shape_b = AnfAlgo::GetInputDeviceShape(matmul, 1);
  auto size_a = shape_a.size();
  // Batch dimensions pass through unpadded (tail 0).
  // NOTE(review): assumes both inputs have rank >= 2 and equal rank — a rank-1
  // shape would underflow `size_a - 2`; confirm callers guarantee this.
  for (size_t dim = 0; dim < size_a - 2; ++dim) {
    pad_shape_a->push_back(shape_a[dim]);
    pad_shape_b->push_back(shape_a[dim]);
    unpad_shape->push_back(shape_a[dim]);
    tail_shape_a->push_back(0);
    tail_shape_b->push_back(0);
    tail_shape_unpad->push_back(0);
  }
  size_t K, M, N, pad_K, pad_M, pad_N;
  using kmn = std::tuple<size_t, size_t, size_t, size_t, size_t, size_t>;
  using func = std::function<kmn(const vec &, const vec &, vec *, vec *)>;
  // Dispatch on the transpose flags to read K/M/N from the right positions and
  // append the padded trailing 2-D shapes of A and B.
  func f = tran_a ? (tran_b ? TransATransB : TransANotTransB) : (tran_b ? NotTransATransB : NotTransANotTransB);
  std::tie(K, M, N, pad_K, pad_M, pad_N) = f(shape_a, shape_b, pad_shape_a, pad_shape_b);
  // Do not pad for the cublas operator: mark the node and bail out.
  if (!IsAkgMatMul(K, M, N)) {
    SetNodeAttrSafely("Akg", MakeValue(false), matmul);
    return std::tuple(false, false, false);
  }
  SetNodeAttrSafely("Akg", MakeValue(true), matmul);
  // Trailing two dims: original output (M, N) and how much padding each side adds.
  unpad_shape->push_back(M);
  unpad_shape->push_back(N);
  tail_shape_unpad->push_back(pad_M - M);
  tail_shape_unpad->push_back(pad_N - N);
  tail_shape_a->push_back(pad_shape_a->at(size_a - 2) - shape_a[size_a - 2]);
  tail_shape_a->push_back(pad_shape_a->at(size_a - 1) - shape_a[size_a - 1]);
  tail_shape_b->push_back(pad_shape_b->at(size_a - 2) - shape_b[size_a - 2]);
  tail_shape_b->push_back(pad_shape_b->at(size_a - 1) - shape_b[size_a - 1]);
  return std::tuple(pad_K != K, pad_M != M, pad_N != N);
}
// Inserts a PadAkg node in front of one matmul input: input A when `left` is
// true, input B when `left` is false. The PadAkg grows the tensor to
// `pad_shape` by appending `tail_shape` zeros per dimension (head padding is
// always zero), then the matmul is rewired to consume the padded tensor.
// NOTE(review): `mng` is accepted but unused here — kept for signature symmetry
// with InsertUnpad, presumably; confirm before removing.
void InsertPad(const CNodePtr &matmul, const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng, bool left,
               const vec &pad_shape, const vec &tail_shape) {
  // CNode input 0 is the primitive itself, so A is input 1 and B is input 2.
  size_t input_index = left ? 1 : 2;
  AnfNodePtrList pad_inp = {NewValueNode(prim::kPrimPadAkg), matmul->input(input_index)};
  auto pad_cnode = func_graph->NewCNode(pad_inp);
  func_graph->AddNode(pad_cnode);
  // Attrs: zero head padding, `tail_shape` tail padding, pad value 0.
  ShapeVector tail;
  (void)tail.insert(tail.begin(), tail_shape.begin(), tail_shape.end());
  ShapeVector head(tail_shape.size(), 0);
  SetNodeAttrSafely("head", MakeValue(head), pad_cnode);
  SetNodeAttrSafely("tail", MakeValue(tail), pad_cnode);
  SetNodeAttrSafely("pad_val", MakeValue(std::make_shared<Int32Imm>(0)), pad_cnode);
  // Abstract: padded shape with the element type of the matmul's first input.
  // NOTE(review): the infer type is always read from input 0, even when padding B.
  std::vector<TypeId> pad_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(matmul, 0)};
  ShapeVector abs_shape;
  (void)abs_shape.insert(abs_shape.begin(), pad_shape.begin(), pad_shape.end());
  auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(abs_shape));
  auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(pad_type[0]), abs_shape_ptr);
  pad_cnode->set_abstract(abstract);
  pad_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  // Kernel build info: single input/output using the matmul's first-input
  // format/type (same note as above — always index 0).
  std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(matmul);
  std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(matmul);
  std::vector<std::string> pad_inp_formats = {input_formats.front()};
  std::vector<TypeId> pad_inp_types = {input_types.front()};
  std::vector<std::string> pad_output_formats = {input_formats.front()};
  std::vector<TypeId> output_types = {input_types.front()};
  auto graph_sel_info = BuildSelectKernelBuildInfo(pad_inp_formats, pad_inp_types, pad_output_formats, output_types);
  AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, pad_cnode.get());
  // Rewire the matmul to read the padded tensor.
  matmul->set_input(input_index, pad_cnode);
}
  172. // unpad_shape is [batch, M, N], tail_shape is [0, pad_M - M, pad_N - N]
  173. void InsertUnpad(const CNodePtr &matmul, const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
  174. const vec &unpad_shape, const vec &tail_shape) {
  175. AnfNodePtrList unpad_inp = {NewValueNode(prim::kPrimUnPadAkg), matmul};
  176. auto unpad_cnode = func_graph->NewCNode(unpad_inp);
  177. func_graph->AddNode(unpad_cnode);
  178. ShapeVector tail;
  179. (void)tail.insert(tail.begin(), tail_shape.begin(), tail_shape.end());
  180. SetNodeAttrSafely("tail", MakeValue(tail), unpad_cnode);
  181. std::vector<TypeId> unpad_type = {common::AnfAlgo::GetOutputInferDataType(matmul, 0)};
  182. ShapeVector abs_shape;
  183. (void)abs_shape.insert(abs_shape.begin(), unpad_shape.begin(), unpad_shape.end());
  184. auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(abs_shape));
  185. auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(unpad_type[0]), abs_shape_ptr);
  186. unpad_cnode->set_abstract(abstract);
  187. unpad_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  188. std::vector<std::string> unpad_input_format = {AnfAlgo::GetOutputFormat(matmul, 0)};
  189. std::vector<TypeId> unpad_input_type = AnfAlgo::GetAllOutputDeviceTypes(matmul);
  190. std::vector<std::string> unpad_output_format = {unpad_input_format.front()};
  191. std::vector<TypeId> unpad_output_type = {unpad_input_type.front()};
  192. auto graph_sel_info =
  193. BuildSelectKernelBuildInfo(unpad_input_format, unpad_input_type, unpad_output_format, unpad_output_type);
  194. AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, unpad_cnode.get());
  195. (void)mng->Replace(matmul, unpad_cnode);
  196. }
  197. // Update matmul's Abatract and BuildInfo as M or N is changed
  198. void UpdateMatmulInfo(const AnfNodePtr &matmul_node, const vec &unpad_shape, const vec &tail_shape) {
  199. ShapeVector abs_shape;
  200. for (size_t i = 0; i < unpad_shape.size(); ++i) {
  201. abs_shape.push_back(unpad_shape[i] + tail_shape[i]);
  202. }
  203. auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(abs_shape));
  204. TypeId abs_type = common::AnfAlgo::GetOutputInferDataType(matmul_node, 0);
  205. auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(abs_type), abs_shape_ptr);
  206. matmul_node->set_abstract(abstract);
  207. std::vector<std::string> input_formats = AnfAlgo::GetAllInputFormats(matmul_node);
  208. std::vector<TypeId> input_types = AnfAlgo::GetAllInputDeviceTypes(matmul_node);
  209. std::vector<std::string> output_formats = AnfAlgo::GetAllOutputFormats(matmul_node);
  210. std::vector<TypeId> output_types = AnfAlgo::GetAllOutputDeviceTypes(matmul_node);
  211. auto graph_sel_info =
  212. BuildSelectKernelBuildInfo(input_formats, input_types, output_formats, output_types, matmul_node);
  213. AnfAlgo::SetSelectKernelBuildInfo(graph_sel_info, matmul_node.get());
  214. }
  215. bool InsertPadUnpad(const FuncGraphPtr &func_graph) {
  216. auto mng = func_graph->manager();
  217. MS_EXCEPTION_IF_NULL(mng);
  218. auto todos = TopoSort(func_graph->get_return());
  219. bool changed = false;
  220. for (const auto &n : todos) {
  221. if (!common::AnfAlgo::CheckPrimitiveType(n, prim::kPrimMatMul)) continue;
  222. auto mm_cnode = n->cast<CNodePtr>();
  223. vec pad_shape_a, pad_shape_b, tail_shape_a, tail_shape_b, tail_shape_unpad, unpad_shape;
  224. bool pad_K{false}, pad_M{false}, pad_N{false};
  225. std::tie(pad_K, pad_M, pad_N) =
  226. NeedPad(mm_cnode, &pad_shape_a, &pad_shape_b, &unpad_shape, &tail_shape_a, &tail_shape_b, &tail_shape_unpad);
  227. if (!pad_K && !pad_M && !pad_N) continue;
  228. if (pad_K || pad_M) {
  229. InsertPad(mm_cnode, func_graph, mng, true, pad_shape_a, tail_shape_a);
  230. }
  231. if (pad_K || pad_N) {
  232. InsertPad(mm_cnode, func_graph, mng, false, pad_shape_b, tail_shape_b);
  233. }
  234. if (pad_M || pad_N) {
  235. UpdateMatmulInfo(mm_cnode, unpad_shape, tail_shape_unpad);
  236. InsertUnpad(mm_cnode, func_graph, mng, unpad_shape, tail_shape_unpad);
  237. }
  238. changed = true;
  239. }
  240. return changed;
  241. }
  242. } // namespace
  243. /* MatMul
  244. *
  245. * C = MatMul(A, B)
  246. * ------>
  247. * A_pad = PadAkg(A)
  248. * B_pad = PadAkg(B)
  249. * C_pad = MatMul(A_pad, B_pad)
  250. * C = UnPadAkg(C_pad)
  251. *
  252. */
  253. bool InsertPadOps::Run(const FuncGraphPtr &func_graph) {
  254. MS_EXCEPTION_IF_NULL(func_graph);
  255. auto mng = func_graph->manager();
  256. if (mng == nullptr) {
  257. mng = Manage(func_graph, true);
  258. func_graph->set_manager(mng);
  259. }
  260. auto changed = false;
  261. auto nodes = TopoSort(func_graph->get_return());
  262. for (auto node : nodes) {
  263. if (!common::AnfAlgo::IsGraphKernel(node)) continue;
  264. auto graph_kernel_fg = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
  265. MS_EXCEPTION_IF_NULL(graph_kernel_fg);
  266. changed = InsertPadUnpad(graph_kernel_fg) || changed;
  267. }
  268. if (changed) {
  269. mng->RemoveRoots();
  270. mng->KeepRoots({func_graph});
  271. }
  272. return changed;
  273. }
  274. } // namespace graphkernel
  275. } // namespace mindspore