You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ms_device_shape_transfer.cc 75 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/ms_device_shape_transfer.h"
  17. #include <functional>
  18. #include <numeric>
  19. #include <utility>
  20. #include <algorithm>
  21. namespace mindspore {
  22. namespace trans {
// Byte widths of the supported element sizes.
const int b1 = 1;
const int b2 = 2;
const int b4 = 4;
const int b8 = 8;
// Cube (fractal) tiling constants used by the FRAC_* / NC1HWC0 device layouts.
const int64_t kCubeSize = 16;
const int64_t kCube16 = kCubeSize;
const int64_t kCube32 = 32;
const int64_t kCube64 = 64;
const int64_t kCubeSize_C04 = 4;  // C0 = 4 variant used by the *_C04 formats
const int64_t kNiSize = 16;       // Ni: fractal block size along the N axis
constexpr int kDims2 = 2;
constexpr int64_t k4 = 4;
// Types whose cube size (C0) is 64 or 32 elements respectively
// (narrower element types pack more elements per cube).
static const std::set<TypeId> C0_64 = {kNumberTypeInt4};
static const std::set<TypeId> C0_32 = {kNumberTypeUInt8, kNumberTypeInt8};
  37. namespace {
  38. bool HasShapeDynamic(const ShapeVector &shape_list) {
  39. return std::any_of(shape_list.begin(), shape_list.end(), [](int64_t v) { return v == abstract::Shape::SHP_ANY; });
  40. }
  41. template <typename T>
  42. T Gcd(T a, T b) {
  43. if (b == 0) {
  44. return 0;
  45. }
  46. T c = b;
  47. while (a % b != 0) {
  48. c = a % b;
  49. a = b;
  50. b = c;
  51. }
  52. return c;
  53. }
  54. template <typename T>
  55. T Lcm(T a, T b) {
  56. if (b == 0) {
  57. return 0;
  58. }
  59. T ret = (a * b) / (Gcd(a, b));
  60. return ret;
  61. }
  62. template <typename T>
  63. T DivCeil(T n1, T n2) {
  64. if (n2 != 0) {
  65. return (n1 + n2 - 1) / n2;
  66. }
  67. return 0;
  68. }
  69. template <typename T>
  70. bool CheckDims(const std::vector<T> &shape) {
  71. if (shape.size() != kNchwDims) {
  72. MS_LOG(ERROR) << "Host shape dims should be 4";
  73. return false;
  74. }
  75. return true;
  76. }
  77. int64_t GetCubeSizeByType(const TypeId &data_type) {
  78. if (C0_32.find(data_type) != C0_32.end()) {
  79. return kCube32;
  80. }
  81. if (C0_64.find(data_type) != C0_64.end()) {
  82. return kCube64;
  83. }
  84. return kCube16;
  85. }
// Pads a shape-range of rank <= 5 up to a full NCDHW (5-D) range.
// The switch keys on ori_range.size(), reusing the axis enums
// (N_ncdhw..W_ncdhw) as rank constants; missing leading dims default to [1, 1].
RangePair PaddingRangeTo5D(const RangePair &ori_range) {
  RangePair dst_range(kNcdhw, std::pair<int64_t, int64_t>(1, 1));
  switch (ori_range.size()) {
    case N_ncdhw:
      // NOTE(review): this returns ori_range unchanged for this rank, while the
      // 4-D variant returns the all-ones dst_range in its first case — confirm
      // the asymmetry is intentional.
      return ori_range;
    case C_ncdhw:
      // Rank 1: the single dim becomes C.
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      break;
    case D_ncdhw:
      // Rank 2: dims become C, D.
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      dst_range[D_ncdhw] = ori_range[C_ncdhw];
      break;
    case H_ncdhw:
      // Rank 3: dims become C, D, H.
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      dst_range[D_ncdhw] = ori_range[C_ncdhw];
      dst_range[H_ncdhw] = ori_range[D_ncdhw];
      break;
    case W_ncdhw:
      // Rank 4: dims become C, D, H, W (N stays [1, 1]).
      dst_range[C_ncdhw] = ori_range[N_ncdhw];
      dst_range[D_ncdhw] = ori_range[C_ncdhw];
      dst_range[H_ncdhw] = ori_range[D_ncdhw];
      dst_range[W_ncdhw] = ori_range[H_ncdhw];
      break;
    default:
      MS_LOG(EXCEPTION) << "Unexpected shape size = " << ori_range.size();
  }
  return dst_range;
}
// Pads a shape-range of rank <= 4 up to a full NCHW (4-D) range; missing
// leading dims default to [1, 1]. The switch keys on ori_range.size(), reusing
// the axis enums (kN..kW, kNchwDims) as rank constants.
RangePair PaddingRangeTo4D(const RangePair &ori_range) {
  RangePair dst_range(kNchwDims, std::pair<int64_t, int64_t>(1, 1));
  switch (ori_range.size()) {
    case kN:
      // NOTE(review): returns the all-ones dst_range here, whereas the 5-D
      // variant returns ori_range in its corresponding case — confirm intended.
      return dst_range;
    case kC:
      // Rank 1: the single dim becomes C.
      dst_range[kC] = ori_range[kN];
      break;
    case kH:
      // Rank 2: dims become C, H.
      dst_range[kC] = ori_range[kN];
      dst_range[kH] = ori_range[kC];
      break;
    case kW:
      // Rank 3: dims become C, H, W (N stays [1, 1]).
      dst_range[kC] = ori_range[kN];
      dst_range[kH] = ori_range[kC];
      dst_range[kW] = ori_range[kH];
      break;
    case kNchwDims:
      // Rank 4: already NCHW, copy as-is.
      (void)std::copy(ori_range.begin(), ori_range.end(), dst_range.begin());
      break;
    default:
      MS_LOG(EXCEPTION) << "Unexpected range size: " << ori_range.size();
  }
  return dst_range;
}
  139. } // namespace
  140. void StringToAxisVector4D(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
  141. MS_EXCEPTION_IF_NULL(reshape_type_vec);
  142. if (reshape_type_str.empty()) {
  143. MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
  144. return;
  145. }
  146. for (const auto &c : reshape_type_str) {
  147. switch (c) {
  148. case 'N':
  149. reshape_type_vec->push_back(N);
  150. break;
  151. case 'C':
  152. reshape_type_vec->push_back(C);
  153. break;
  154. case 'H':
  155. reshape_type_vec->push_back(H);
  156. break;
  157. case 'W':
  158. reshape_type_vec->push_back(W);
  159. break;
  160. default:
  161. MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
  162. }
  163. }
  164. }
  165. void StringToAxisVector5D(const std::string &reshape_type_str, std::vector<Axis5D> *reshape_type_vec) {
  166. MS_EXCEPTION_IF_NULL(reshape_type_vec);
  167. if (reshape_type_str.empty()) {
  168. MS_LOG(DEBUG) << "Reshape type str is empty, no need padding.";
  169. return;
  170. }
  171. for (const auto &c : reshape_type_str) {
  172. switch (c) {
  173. case 'N':
  174. reshape_type_vec->push_back(N_ncdhw);
  175. break;
  176. case 'C':
  177. reshape_type_vec->push_back(C_ncdhw);
  178. break;
  179. case 'D':
  180. reshape_type_vec->push_back(D_ncdhw);
  181. break;
  182. case 'H':
  183. reshape_type_vec->push_back(H_ncdhw);
  184. break;
  185. case 'W':
  186. reshape_type_vec->push_back(W_ncdhw);
  187. break;
  188. default:
  189. MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type.";
  190. }
  191. }
  192. }
  193. bool IsNeedPadding(const std::string &format, size_t shape_size) {
  194. if (shape_size == 0) {
  195. return false;
  196. }
  197. if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW ||
  198. kNoPaddingFormatSet.find(format) != kNoPaddingFormatSet.end()) {
  199. return false;
  200. } else if (shape_size < kNchwDims) {
  201. return true;
  202. }
  203. return false;
  204. }
  205. ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  206. MS_EXCEPTION_IF_NULL(node);
  207. ShapeVector shape;
  208. std::vector<size_t> host_shape;
  209. if (node->isa<ValueNode>()) {
  210. auto value_node = node->cast<ValueNodePtr>();
  211. MS_EXCEPTION_IF_NULL(value_node);
  212. auto node_value = value_node->value();
  213. MS_EXCEPTION_IF_NULL(node_value);
  214. auto tensor = node_value->cast<tensor::TensorPtr>();
  215. if (tensor == nullptr) {
  216. MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
  217. }
  218. auto shape_temp = tensor->shape();
  219. (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), LongToSize);
  220. if (host_shape.empty()) {
  221. host_shape.push_back(1);
  222. }
  223. } else {
  224. host_shape = AnfAlgo::GetOutputInferShape(node, index);
  225. }
  226. auto format = AnfAlgo::GetOutputFormat(node, index);
  227. if (IsNeedPadding(format, host_shape.size())) {
  228. host_shape = PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index));
  229. }
  230. std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
  231. return shape;
  232. }
  233. bool TransDataType(const TypeIdArgs &args, void *result) {
  234. DataTypeTransfer dataTypeTransfer;
  235. return dataTypeTransfer.TransDataType(args, result);
  236. }
  237. bool TransFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index) {
  238. FormatTransfer formatTransfer;
  239. return formatTransfer.TransDataByFormat(args, result, node, index, true);
  240. }
  241. bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, int64_t groups) {
  242. FormatTransfer formatTransfer;
  243. return formatTransfer.TransDataBackwordCore(args, result, groups);
  244. }
  245. bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index) {
  246. FormatTransfer formatTransfer;
  247. return formatTransfer.TransDataByFormat(args, result, node, index, false);
  248. }
  249. /**###################### DATA TYPE TRANS ################################*/
// Validates a datatype-cast request: both type ids must have a known byte
// width, and the source buffer size must hold exactly src_shape_size elements.
// Throws via MS_LOG(EXCEPTION) on any mismatch.
void CheckMemSize(const TypeIdArgs &args) {
  auto src_type_size = abstract::TypeIdSize(args.src_data_type);
  auto dst_type_size = abstract::TypeIdSize(args.dst_data_type);
  if (src_type_size < 1 || dst_type_size < 1) {
    // TypeIdSize returns 0 for unknown types.
    MS_LOG(EXCEPTION) << "Invalid src or dst data type. Src type: " << TypeIdLabel(args.src_data_type)
                      << ", dst type: " << TypeIdLabel(args.dst_data_type);
  }
  if (SizeToLong(args.data_size / src_type_size) != args.src_shape_size) {
    MS_LOG(EXCEPTION) << "Invalid src or dst data shape size. Src shape size: " << args.src_shape_size
                      << ", dst shape size: " << args.data_size / src_type_size;
  }
}
  262. template <typename SrcT, typename DstT>
  263. void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const int64_t data_size) {
  264. CheckMemSize(args);
  265. for (int64_t idx = 0; idx != data_size; idx++) {
  266. SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
  267. static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
  268. }
  269. }
  270. template <typename SrcT>
  271. void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const int64_t data_size) {
  272. CheckMemSize(args);
  273. auto src_data = static_cast<const SrcT *>(args.data);
  274. auto half_data = static_cast<float16 *>(dst);
  275. for (int64_t i = 0; i < data_size; i++) {
  276. half_data[i] = float16(src_data[i]);
  277. }
  278. }
// Dispatch a single cast mode to its element-wise kernel.
// float <-> float16 use the optimized device helpers; every other mode goes
// through the kernel map. Returns false when no kernel exists for `mode`.
bool DataTypeTransfer::CastKernel(const TypeIdArgs &args, void *dst, int64_t data_size, DataTypeTransMode mode) {
  using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const int64_t)>;
  // NOTE: bool values are handled as int8_t buffers, hence the int8_t kernels
  // for FROM_BOOL_* and the int8_t destination for FROM_INT32_TO_BOOL.
  const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
    {DataTypeTransMode::FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
    {DataTypeTransMode::FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
    {DataTypeTransMode::FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
    {DataTypeTransMode::FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
    {DataTypeTransMode::FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
    {DataTypeTransMode::FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
    {DataTypeTransMode::FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
    {DataTypeTransMode::FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
    {DataTypeTransMode::FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
    {DataTypeTransMode::FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
    {DataTypeTransMode::FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
    {DataTypeTransMode::FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
    {DataTypeTransMode::FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
    {DataTypeTransMode::FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
    {DataTypeTransMode::FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
    {DataTypeTransMode::FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
    {DataTypeTransMode::FROM_INT32_TO_INT64, TransDataSrc2Dst<int32_t, int64_t>},
    {DataTypeTransMode::FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
    {DataTypeTransMode::FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
    {DataTypeTransMode::FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
    {DataTypeTransMode::FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
    {DataTypeTransMode::FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
    {DataTypeTransMode::FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>}};
  // Fast paths: dedicated half-precision converters.
  if (mode == DataTypeTransMode::FROM_FLOAT_TO_FLOAT16) {
    device::FloatToHalf(dst, args.data, data_size);
    return true;
  } else if (mode == DataTypeTransMode::FROM_FLOAT16_TO_FLOAT) {
    device::HalfToFloat(dst, args.data, data_size);
    return true;
  }
  auto iter = cast_kernel_map.find(mode);
  if (iter != cast_kernel_map.end()) {
    iter->second(args, dst, data_size);
    return true;
  } else {
    MS_LOG(ERROR) << "Can not find a datatype trans function. Src type :" << TypeIdLabel(args.src_data_type)
                  << ", dst_type:" << TypeIdLabel(args.dst_data_type);
    return false;
  }
}
  322. bool DataTypeTransfer::TransDataType(const TypeIdArgs &args, void *result) {
  323. MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_data_type) << " to "
  324. << TypeIdLabel(args.dst_data_type);
  325. MS_EXCEPTION_IF_NULL(result);
  326. std::pair<TypeId, TypeId> type_info(args.src_data_type, args.dst_data_type);
  327. auto iter = mode_map.find(type_info);
  328. if (iter == mode_map.end()) {
  329. MS_LOG(ERROR) << "Can not find a datatype trans type. src_type :" << TypeIdLabel(args.src_data_type)
  330. << ", dst_type:" << TypeIdLabel(args.dst_data_type);
  331. return false;
  332. }
  333. auto trans_mode = iter->second;
  334. if (!CastKernel(args, result, args.src_shape_size, trans_mode)) {
  335. MS_LOG(ERROR) << "Failed to trans datatype. Src: " << TypeIdLabel(args.src_data_type)
  336. << ", dst: " << TypeIdLabel(args.dst_data_type);
  337. return false;
  338. }
  339. return true;
  340. }
  341. /**###################### DATA SHAPE TRANS ################################*/
// Node-aware host -> device shape transform.
// A fixed device shape attached to the node as an attr wins outright;
// otherwise node attrs supply `groups` (FRAC_Z) and input/hidden sizes
// (RNN formats) before delegating to TransCore.
ShapeVector DeviceShapeTransfer::GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format,
                                                        const AnfNodePtr &node, size_t index, const TypeId &type,
                                                        bool is_output) {
  auto dev_shape = GetFixedDeviceShape(shape, node, index, is_output);
  if (dev_shape.has_value()) {
    return dev_shape.value();
  }
  int64_t groups = 1;
  if (format == kOpFormat_FRAC_Z) {
    groups = AnfAlgo::GetAttrGroups(node, index);
  }
  // Default input/hidden sizes; only the RNN formats read real values.
  ShapeVector input_hidden_size = {kAlign16, kAlign16};
  if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
    input_hidden_size = GetAttrInputAndHiddenSize(node);
  }
  if (node != nullptr) {
    MS_LOG(DEBUG) << "Start trans infer shape to device shape for node: " << node->DebugString()
                  << ", format: " << format;
  }
  return TransCore(shape, format, type, groups, input_hidden_size);
}
// Node-free overload: forwards straight to the core transform with the
// caller-supplied groups and input/hidden sizes.
ShapeVector DeviceShapeTransfer::GetDeviceShapeByFormat(const ShapeVector &shape, const std::string &format,
                                                        const TypeId &type, int64_t groups,
                                                        const ShapeVector &input_hidden_size) {
  return TransCore(shape, format, type, groups, input_hidden_size);
}
  368. std::optional<ShapeVector> DeviceShapeTransfer::GetFixedDeviceShape(const ShapeVector &, const AnfNodePtr &node,
  369. size_t index, bool is_output) {
  370. if (node == nullptr || !node->isa<CNode>()) {
  371. return {};
  372. }
  373. auto attr_name = is_output ? kAttrFixedOutputDeviceShape : kAttrFixedInputDeviceShape;
  374. auto cnode = node->cast<CNodePtr>();
  375. if (!AnfAlgo::HasNodeAttr(attr_name, cnode)) {
  376. return {};
  377. }
  378. auto shapes = AnfAlgo::GetNodeAttr<std::vector<ShapeVector>>(cnode, attr_name);
  379. if (index >= shapes.size()) {
  380. MS_LOG(INFO) << "Index is out of range, got index: " << index << ", shape size: " << shapes.size();
  381. return {};
  382. }
  383. return std::optional<ShapeVector>(std::move(shapes[index]));
  384. }
// Core host -> device shape transform: pick the transfer matching `format`,
// padding the host shape to 4-D or 5-D first when the format requires it.
ShapeVector DeviceShapeTransfer::TransCore(const ShapeVector &shape, const std::string &format, const TypeId &type,
                                           int64_t groups, const ShapeVector &input_hidden_size) {
  using DeviceShapeTransfer = std::function<ShapeVector(const ShapeVector &, const TypeId &)>;
  // Formats handled by a uniform (shape, type) -> device-shape function.
  const std::map<std::string, DeviceShapeTransfer> device_shape_map = {
    {kOpFormat_NCHW, NCHWDeviceShape},
    {kOpFormat_NHWC, NHWCDeviceShape},
    {kOpFormat_HWCN, HWCNDeviceShape},
    {kOpFormat_NCDHW, NCDHWDeviceShape},
    {kOpFormat_FRAC_Z, FRAC_ZDeviceShape},
    {kOpFormat_FRAC_NZ, FRAC_NZDeviceShape},
    {kOpFormat_NC1HWC0, NC1HWC0DeviceShape},
    {kOpFormat_NDC1HWC0, NDC1HWC0DeviceShape},
    {kOpFormat_C1HWNCoC0, C1HWNCOC0DeviceShape},
    {kOpFormat_NC1HWC0_C04, NC1HWC04DeviceShape},
    {kOpFormat_FRACTAL_Z_3D, FRAC_Z3DDeviceShape},
    {kOpFormat_FRACTAL_Z_C04, FRAC_ZC04DeviceShape},
    {kOpFormat_ChannelLast, ChannelLastDeviceShape},
    {kOpFormat_FRACTAL_ZN_LSTM, FRAC_ZN_LSTMDeviceShape}};
  // ND / default / NCHW pass through unchanged.
  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) {
    return shape;
  }
  // Formats that need extra arguments are handled outside the map.
  if (groups > 1 && format == kOpFormat_FRAC_Z) {
    return FRAC_ZDeviceShapeWithGroups(shape, type, groups);
  }
  if (format == kOpFormat_FRACTAL_ZN_RNN) {
    return FRAC_ZN_RNNDeviceShape(shape, type, input_hidden_size);
  }
  if (format == kOpFormat_ND_RNN_BIAS) {
    return NDRNNBiasDeviceShape(shape, type, input_hidden_size[1]);
  }
  auto temp_shape = shape;
  // Pad to 4-D for cube formats (LSTM, no-padding and 3-D formats excluded).
  if (kNoPaddingFormatSet.find(format) == kNoPaddingFormatSet.end() && format != kOpFormat_FRACTAL_ZN_LSTM &&
      shape.size() < kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
    MS_LOG(WARNING) << "Origin shape size is less than 4, should be Padding shape by Default firstly";
    temp_shape = PaddingShapeTo4dDefault(shape);
  }
  // Pad to 5-D for the 3-D (NCDHW-family) formats.
  if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
    temp_shape = PaddingShapeTo5dDefault(shape);
  }
  auto iter = device_shape_map.find(format);
  if (iter == device_shape_map.end()) {
    MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  }
  return iter->second(temp_shape, type);
}
// NCHW is already the host layout: validate the rank and return the shape as-is.
ShapeVector DeviceShapeTransfer::NCHWDeviceShape(const ShapeVector &shape, const TypeId &) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  return shape;
}
  436. ShapeVector DeviceShapeTransfer::NHWCDeviceShape(const ShapeVector &shape, const TypeId &) {
  437. if (!CheckDims(shape)) {
  438. MS_LOG(EXCEPTION) << "Check dims failed.";
  439. }
  440. ShapeVector device_shape;
  441. device_shape.push_back(shape[kN]);
  442. device_shape.push_back(shape[kH]);
  443. device_shape.push_back(shape[kW]);
  444. device_shape.push_back(shape[kC]);
  445. return device_shape;
  446. }
  447. ShapeVector DeviceShapeTransfer::HWCNDeviceShape(const ShapeVector &shape, const TypeId &) {
  448. if (!CheckDims(shape)) {
  449. MS_LOG(EXCEPTION) << "Check dims failed.";
  450. }
  451. ShapeVector device_shape;
  452. device_shape.push_back(shape[kH]);
  453. device_shape.push_back(shape[kW]);
  454. device_shape.push_back(shape[kC]);
  455. device_shape.push_back(shape[kN]);
  456. return device_shape;
  457. }
// NCHW -> FRAC_Z (fractal-Z): [C1 * H * W, N1, Ni, C0] with C1 = ceil(C / C0)
// and N1 = ceil(N / Ni). A dynamic dim among C/H/W makes the first device dim
// dynamic; a dynamic N makes the second dynamic.
ShapeVector DeviceShapeTransfer::FRAC_ZDeviceShape(const ShapeVector &shape, const TypeId &type) {
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  ShapeVector device_shape;
  auto c0 = GetCubeSizeByType(type);
  if (HasShapeDynamic({shape[kC], shape[kH], shape[kW]})) {
    device_shape.push_back(abstract::Shape::SHP_ANY);
  } else {
    auto c1 = (shape[kC] + c0 - 1) / c0;
    device_shape.push_back(shape[kH] * shape[kW] * c1);
  }
  if (shape[kN] == abstract::Shape::SHP_ANY) {
    device_shape.push_back(abstract::Shape::SHP_ANY);
  } else {
    auto no = (shape[kN] + kNiSize - 1) / kNiSize;
    device_shape.push_back(no);
  }
  device_shape.push_back(kNiSize);
  device_shape.push_back(c0);
  return device_shape;
}
  480. ShapeVector DeviceShapeTransfer::NC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type) {
  481. if (!CheckDims(shape)) {
  482. MS_LOG(EXCEPTION) << "Check dims failed.";
  483. }
  484. ShapeVector device_shape;
  485. auto c0 = GetCubeSizeByType(type);
  486. auto c1 = (shape[kC] == abstract::Shape::SHP_ANY) ? abstract::Shape::SHP_ANY : (shape[kC] + c0 - 1) / c0;
  487. device_shape.push_back(shape[kN]);
  488. device_shape.push_back(c1);
  489. device_shape.push_back(shape[kH]);
  490. device_shape.push_back(shape[kW]);
  491. device_shape.push_back(c0);
  492. return device_shape;
  493. }
  494. ShapeVector DeviceShapeTransfer::NDC1HWC0DeviceShape(const ShapeVector &shape, const TypeId &type) {
  495. if (shape.size() != kNcdhw) {
  496. MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  497. }
  498. ShapeVector device_shape;
  499. auto c0 = GetCubeSizeByType(type);
  500. auto c1 = (shape[1] == abstract::Shape::SHP_ANY) ? abstract::Shape::SHP_ANY : (shape[1] + c0 - 1) / c0;
  501. device_shape.push_back(shape[N_ncdhw]);
  502. device_shape.push_back(shape[D_ncdhw]);
  503. device_shape.push_back(c1);
  504. device_shape.push_back(shape[H_ncdhw]);
  505. device_shape.push_back(shape[W_ncdhw]);
  506. device_shape.push_back(c0);
  507. return device_shape;
  508. }
  509. ShapeVector DeviceShapeTransfer::FRAC_Z3DDeviceShape(const ShapeVector &shape, const TypeId &type) {
  510. if (shape.size() != kNcdhw) {
  511. MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  512. }
  513. ShapeVector device_shape;
  514. auto c0 = GetCubeSizeByType(type);
  515. if (HasShapeDynamic({shape[C_ncdhw], shape[D_ncdhw], shape[H_ncdhw], shape[W_ncdhw]})) {
  516. device_shape.push_back(abstract::Shape::SHP_ANY);
  517. } else {
  518. auto c1 = (shape[1] + c0 - 1) / c0;
  519. device_shape.push_back(shape[D_ncdhw] * c1 * shape[H_ncdhw] * shape[W_ncdhw]);
  520. }
  521. auto no = (shape[0] == abstract::Shape::SHP_ANY) ? abstract::Shape::SHP_ANY : (shape[0] + kNiSize - 1) / kNiSize;
  522. device_shape.push_back(no);
  523. device_shape.push_back(kNiSize);
  524. device_shape.push_back(c0);
  525. return device_shape;
  526. }
  527. ShapeVector DeviceShapeTransfer::C1HWNCOC0DeviceShape(const ShapeVector &shape, const TypeId &type) {
  528. if (!CheckDims(shape)) {
  529. MS_LOG(EXCEPTION) << "Check dims failed.";
  530. }
  531. ShapeVector device_shape;
  532. auto c0 = GetCubeSizeByType(type);
  533. if (shape[kC] == abstract::Shape::SHP_ANY) {
  534. device_shape.push_back(abstract::Shape::SHP_ANY);
  535. } else {
  536. device_shape.push_back((shape[kC] - 1) / c0 + 1);
  537. }
  538. device_shape.push_back(shape[kH]);
  539. device_shape.push_back(shape[kW]);
  540. device_shape.push_back(shape[kN]);
  541. device_shape.push_back(c0);
  542. device_shape.push_back(c0);
  543. return device_shape;
  544. }
  545. ShapeVector DeviceShapeTransfer::FRAC_ZC04DeviceShape(const ShapeVector &shape, const TypeId &type) {
  546. if (!CheckDims(shape)) {
  547. MS_LOG(EXCEPTION) << "Check dims failed.";
  548. }
  549. ShapeVector device_shape;
  550. const int64_t C04 = 4;
  551. int64_t first_dim;
  552. if (HasShapeDynamic({shape[kH], shape[kW]})) {
  553. first_dim = abstract::Shape::SHP_ANY;
  554. } else {
  555. first_dim = DivCeil(C04 * shape[kH] * shape[kW], kCubeSize);
  556. }
  557. auto no = (shape[kN] == abstract::Shape::SHP_ANY) ? abstract::Shape::SHP_ANY : DivCeil(shape.at(kN), kCubeSize);
  558. device_shape.push_back(first_dim);
  559. device_shape.push_back(no);
  560. device_shape.push_back(kCubeSize);
  561. device_shape.push_back(kCubeSize);
  562. return device_shape;
  563. }
  564. ShapeVector DeviceShapeTransfer::NC1HWC04DeviceShape(const ShapeVector &shape, const TypeId &) {
  565. if (!CheckDims(shape)) {
  566. MS_LOG(EXCEPTION) << "Check dims failed.";
  567. }
  568. ShapeVector device_shape;
  569. const int64_t C04 = 4;
  570. const int64_t C1 = (shape[kC] == abstract::Shape::SHP_ANY) ? abstract::Shape::SHP_ANY : DivCeil(shape.at(kC), C04);
  571. device_shape.push_back(shape[kN]);
  572. device_shape.push_back(C1);
  573. device_shape.push_back(shape[kH]);
  574. device_shape.push_back(shape[kW]);
  575. device_shape.push_back(C04);
  576. return device_shape;
  577. }
  578. ShapeVector DeviceShapeTransfer::NCDHWDeviceShape(const ShapeVector &shape, const TypeId &) {
  579. if (shape.size() < kNcdhw) {
  580. MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  581. }
  582. return shape;
  583. }
// Channel-first -> channel-last: builds the permutation [0, 2, 3, ..., 1]
// (channel axis moved to the end) and applies it to `shape`.
// NOTE(review): assumes shape is non-empty; axis[dim - 1] would be out of
// bounds for rank 0 — confirm callers never pass an empty shape.
ShapeVector DeviceShapeTransfer::ChannelLastDeviceShape(const ShapeVector &shape, const TypeId &) {
  auto dim = shape.size();
  ShapeVector axis;
  axis.resize(dim);
  const int step_value = 2;
  // axis = [0, 2, 3, 4, ...]; the last entry is then overwritten with 1,
  // yielding the channel-last permutation (axis[0] stays value-initialized 0).
  std::iota(axis.begin() + 1, axis.end(), step_value);
  axis[dim - 1] = 1;
  ShapeVector device_shape;
  (void)std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
                       [&shape](int64_t n) { return shape[n]; });
  return device_shape;
}
ShapeVector DeviceShapeTransfer::FRAC_NZDeviceShape(const ShapeVector &shape, const TypeId &type) {
  // FRACTAL_NZ: leading (batch) dims are copied and the trailing two dims
  // (H, W) are re-tiled into [W1, H1, H0=kCubeSize, W0=c0], where c0 is the
  // dtype-dependent cube size.
  ShapeVector device_shape;
  auto c0 = GetCubeSizeByType(type);
  if (shape.size() == 1 && (shape[0] == 1 || shape[0] % c0 == 0)) {
    // For [1] and [1024] shape we can trait it as NZ shape
    return shape;
  }
  if (shape.size() < kShape2dDims) {
    MS_LOG(EXCEPTION) << "Format FRACTAL_NZ don't support shape with " << shape.size() << " dims";
  } else {
    // Copy everything except the trailing (H, W) pair, which is re-tiled below.
    const auto remove_dim = 2;
    (void)std::copy(shape.begin(), shape.end() - remove_dim, std::back_inserter(device_shape));
  }
  // Second-to-last dim; relies on kH == 2 (NCHW index constant) — TODO confirm.
  int64_t h_shape = shape[shape.size() - kH];
  int64_t w_shape = shape[shape.size() - 1];
  // Dynamic dims stay SHP_ANY; otherwise round up to whole tiles.
  int64_t w1 = (w_shape == abstract::Shape::SHP_ANY) ? abstract::Shape::SHP_ANY : (w_shape - 1) / c0 + 1;
  int64_t h1 = (h_shape == abstract::Shape::SHP_ANY) ? abstract::Shape::SHP_ANY : (h_shape - 1) / kCubeSize + 1;
  device_shape.push_back(w1);
  device_shape.push_back(h1);
  device_shape.push_back(kCubeSize);
  device_shape.push_back(c0);
  return device_shape;
}
  619. ShapeVector DeviceShapeTransfer::FRAC_ZN_LSTMDeviceShape(const ShapeVector &shape, const TypeId &type) {
  620. ShapeVector device_shape;
  621. const int64_t lstm_ni = 4;
  622. const int64_t ni = 16;
  623. int64_t first = abstract::Shape::SHP_ANY;
  624. int64_t second = abstract::Shape::SHP_ANY;
  625. if (!HasShapeDynamic({shape[kN], shape[kC]})) {
  626. const int64_t h = shape.at(kN) / lstm_ni;
  627. const int64_t i = shape.at(kC) - h;
  628. first = DivCeil(i, ni) + DivCeil(h, ni);
  629. second = lstm_ni * DivCeil(h, ni);
  630. }
  631. device_shape.push_back(first);
  632. device_shape.push_back(second);
  633. device_shape.push_back(ni);
  634. device_shape.push_back(ni);
  635. return device_shape;
  636. }
ShapeVector DeviceShapeTransfer::FRAC_ZDeviceShapeWithGroups(const ShapeVector &shape, const TypeId &type,
                                                             int64_t groups) {
  // FRAC_Z shape for grouped convolution weights. Several original groups are
  // packed into one optimized group (factor e_mult) so that both Cin and Cout
  // per packed group align to the cube size.
  if (!CheckDims(shape)) {
    MS_LOG(EXCEPTION) << "Check dims failed.";
  }
  if (groups <= 0) {
    MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  }
  auto cube_size = GetCubeSizeByType(type);
  auto c1_dim = abstract::Shape::SHP_ANY;
  auto g_dim = abstract::Shape::SHP_ANY;
  auto n1 = abstract::Shape::SHP_ANY;
  if (!HasShapeDynamic({shape[kC], shape[kN]})) {
    auto group_size = groups;
    auto cin_ori_tmp = static_cast<int64_t>(shape[kC]);           // Cin per group (host C)
    auto cout_ori_tmp = static_cast<int64_t>(shape[kN]) / group_size;  // Cout per group
    // Smallest multiplier that aligns both Cin and Cout to the cube size,
    // capped at the group count.
    auto e_mult =
      std::min(Lcm(Lcm(cin_ori_tmp, cube_size) / cin_ori_tmp, Lcm(cout_ori_tmp, cube_size) / cout_ori_tmp), group_size);
    auto cin_opt = DivCeil(e_mult * cin_ori_tmp, cube_size) * cube_size;  // padded Cin per packed group
    c1_dim = cin_opt / cube_size;
    g_dim = DivCeil(group_size, e_mult);  // number of packed groups
    n1 = DivCeil(cout_ori_tmp * e_mult, cube_size);
  }
  ShapeVector device_shape;
  // First dim folds packed groups, C1 and the spatial extent together; any
  // dynamic input dim makes it SHP_ANY.
  if (!HasShapeDynamic({shape[kC], shape[kN], shape[kH], shape[kW]})) {
    device_shape.push_back(g_dim * c1_dim * shape[kH] * shape[kW]);
  } else {
    device_shape.push_back(abstract::Shape::SHP_ANY);
  }
  device_shape.push_back(n1);
  device_shape.push_back(kNiSize);
  device_shape.push_back(cube_size);
  return device_shape;
}
  671. ShapeVector DeviceShapeTransfer::FRAC_ZN_RNNDeviceShape(const ShapeVector &shape, const TypeId &type,
  672. const ShapeVector &input_hidden_size) {
  673. if (shape.size() < kShape2dDims) {
  674. MS_LOG(EXCEPTION) << "Format FRACTAL_NZ_RNN don't support shape with " << shape.size() << " dims";
  675. }
  676. auto C0 = GetCubeSizeByType(type);
  677. auto input_size = input_hidden_size[0];
  678. auto hidden_size = input_hidden_size[1];
  679. auto dim_last1 = shape[shape.size() - 1];
  680. auto dim_last2 = shape[shape.size() - kDim2];
  681. const int64_t NUM16 = 16;
  682. ShapeVector device_shape = shape;
  683. if (dim_last2 == abstract::Shape::SHP_ANY) {
  684. device_shape[shape.size() - kDim2] = abstract::Shape::SHP_ANY;
  685. } else if (dim_last2 == input_size || dim_last2 == hidden_size) {
  686. device_shape[shape.size() - kDim2] = DivCeil(dim_last2, NUM16);
  687. } else if (dim_last2 == input_size + hidden_size) {
  688. device_shape[shape.size() - kDim2] = DivCeil(input_size, NUM16) + DivCeil(hidden_size, NUM16);
  689. } else {
  690. MS_LOG(EXCEPTION) << "The second-last dim value of shape is invalid.";
  691. }
  692. if (dim_last1 == abstract::Shape::SHP_ANY) {
  693. device_shape[shape.size() - kDim1] = abstract::Shape::SHP_ANY;
  694. } else {
  695. if (dim_last1 % hidden_size != 0) {
  696. MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
  697. }
  698. int64_t n_num = shape[shape.size() - 1] / hidden_size;
  699. device_shape[shape.size() - kDim1] = n_num * DivCeil(hidden_size, C0);
  700. }
  701. device_shape.push_back(NUM16);
  702. device_shape.push_back(C0);
  703. return device_shape;
  704. }
  705. ShapeVector DeviceShapeTransfer::NDRNNBiasDeviceShape(const ShapeVector &shape, const TypeId &type,
  706. int64_t hidden_size) {
  707. if (shape.empty()) {
  708. MS_LOG(EXCEPTION) << "Format ND_RNN_BIAS don't support empty shape.";
  709. }
  710. auto C0 = GetCubeSizeByType(type);
  711. ShapeVector device_shape = shape;
  712. // cppcheck-suppress *
  713. auto dim_last1 = shape[shape.size() - 1];
  714. if (dim_last1 == abstract::Shape::SHP_ANY) {
  715. device_shape[shape.size() - 1] = abstract::Shape::SHP_ANY;
  716. } else {
  717. if (hidden_size <= 0 || dim_last1 % hidden_size != 0) {
  718. MS_LOG(EXCEPTION) << "Last dim of shape " << shape << " should be multiple of hidden_size " << hidden_size;
  719. }
  720. int64_t n_num = shape[shape.size() - 1] / hidden_size;
  721. device_shape[shape.size() - 1] = n_num * DivCeil(hidden_size, C0) * C0;
  722. }
  723. return device_shape;
  724. }
  725. ShapeVector DeviceShapeTransfer::GetAttrInputAndHiddenSize(const AnfNodePtr &node) {
  726. MS_EXCEPTION_IF_NULL(node);
  727. std::vector<int64_t> input_hidden_size = {kAlign16, kAlign16};
  728. if (!node->isa<CNode>() && !node->isa<Parameter>()) {
  729. return input_hidden_size;
  730. }
  731. if (node->isa<Parameter>()) {
  732. auto param = node->cast<ParameterPtr>();
  733. input_hidden_size[0] = param->input_size();
  734. input_hidden_size[1] = param->hidden_size();
  735. } else {
  736. CNodePtr cnode = node->cast<CNodePtr>();
  737. if (cnode == nullptr || !AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode) ||
  738. !AnfAlgo::HasNodeAttr(kAttrInputSize, cnode)) {
  739. MS_LOG(EXCEPTION)
  740. << "Node with format FRACTAL_ZN_RNN or ND_RNN_BIAS should have hidden_size or input_size attr. Node info:"
  741. << node->DebugString();
  742. }
  743. input_hidden_size[0] = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrInputSize);
  744. input_hidden_size[1] = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrHiddenSize);
  745. }
  746. return input_hidden_size;
  747. }
  748. /**###################### DATA FORMAT TRANS ################################*/
// Copies one element of `size` bytes from args.data[src_idx] into
// result[dst_idx], or writes zero when pad_zero is set (used to fill the
// padding slots of fractal tiles). Only 1/2/4/8-byte elements are supported.
inline void SetData(int64_t size, bool pad_zero, int64_t src_idx, int64_t dst_idx, const FormatArgs &args,
                    void *result) {
  switch (size) {
    case b1:
      static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
      break;
    case b2:
      static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
      break;
    case b4:
      static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
      break;
    case b8:
      static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
      break;
    default:
      MS_LOG(EXCEPTION) << "Trans data not support size " << size;
  }
}
  768. bool FormatTransfer::TransDataByFormat(const FormatArgs &args, void *result, const AnfNodePtr &node, size_t index,
  769. bool is_forward) {
  770. int64_t groups = 1;
  771. if (args.device_format == kOpFormat_FRAC_Z && node != nullptr) {
  772. groups = AnfAlgo::GetAttrGroups(node, index);
  773. }
  774. if (is_forward) {
  775. return TransDataForwardCore(args, result, groups);
  776. }
  777. return TransDataBackwordCore(args, result, groups);
  778. }
  779. bool FormatTransfer::TransDataForwardCore(const FormatArgs &args, void *result, int64_t groups) {
  780. MS_LOG(DEBUG) << "Start trans format.";
  781. if (abstract::TypeIdSize(args.src_data_type) < 1) {
  782. MS_LOG(ERROR) << "Invalid datatype: " << args.src_data_type;
  783. return false;
  784. }
  785. if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
  786. return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, true, groups);
  787. }
  788. auto iter = format_trans_fp_map.find(args.device_format);
  789. if (iter == format_trans_fp_map.end()) {
  790. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  791. }
  792. return iter->second(args, result);
  793. }
  794. bool FormatTransfer::TransDataBackwordCore(const FormatArgs &args, void *result, int64_t groups) {
  795. MS_LOG(DEBUG) << "Start trans format.";
  796. if (abstract::TypeIdSize(args.src_data_type) < 1) {
  797. MS_LOG(ERROR) << "Invalid datatype, type: " << args.src_data_type;
  798. return false;
  799. }
  800. if (groups > 1 && args.device_format == kOpFormat_FRAC_Z) {
  801. return FRAC_Z_TO_NCHW_WITH_GROUPS(args, result, groups);
  802. }
  803. auto iter = format_trans_bp_map.find(args.device_format);
  804. if (iter == format_trans_bp_map.end()) {
  805. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  806. }
  807. return iter->second(args, result);
  808. }
  809. bool FormatTransfer::CheckArgs(const FormatArgs &args, int64_t *size) {
  810. if (args.host_shape.size() != kNchwDims) {
  811. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  812. return false;
  813. }
  814. MS_EXCEPTION_IF_NULL(size);
  815. *size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  816. if (*size < 1) {
  817. MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
  818. return false;
  819. }
  820. auto total_size = abstract::ShapeSize(args.device_shape) * (*size);
  821. if (total_size != SizeToLong(args.device_size)) {
  822. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  823. return false;
  824. }
  825. return true;
  826. }
  827. bool FormatTransfer::TransShapeToHW_NZ(const ShapeVector &host_shape, ShapeVector *hw_shape) {
  828. MS_EXCEPTION_IF_NULL(hw_shape);
  829. if (host_shape.empty()) {
  830. MS_LOG(ERROR) << "Size of vector is 0.";
  831. return false;
  832. }
  833. switch (host_shape.size()) {
  834. case 1:
  835. hw_shape->push_back(1);
  836. hw_shape->push_back(1);
  837. hw_shape->push_back(host_shape[0]);
  838. return true;
  839. default:
  840. auto size = host_shape.size();
  841. if (size < kDim2) {
  842. MS_LOG(ERROR) << "Illegal size: " << size;
  843. return false;
  844. }
  845. int64_t times = 1;
  846. for (size_t i = 0; i != size - kDim2; i++) {
  847. times *= host_shape[i];
  848. }
  849. hw_shape->push_back(times);
  850. hw_shape->push_back(host_shape[size - kDim2]);
  851. hw_shape->push_back(host_shape[size - kDim1]);
  852. return true;
  853. }
  854. }
  855. bool FormatTransfer::NCHW_TO_4D(const FormatArgs &args, void *result) {
  856. // trans nchw to NHWC or HWCN
  857. MS_LOG(DEBUG) << "Trans format from nchw to " << args.device_format;
  858. MS_EXCEPTION_IF_NULL(result);
  859. int64_t size = 0;
  860. if (!CheckArgs(args, &size)) {
  861. MS_LOG(ERROR) << "Check args failed.";
  862. return false;
  863. }
  864. auto n = args.host_shape[kN];
  865. auto c = args.host_shape[kC];
  866. auto h = args.host_shape[kH];
  867. auto w = args.host_shape[kW];
  868. for (int64_t ni = 0; ni < n; ni++) {
  869. for (int64_t ci = 0; ci < c; ci++) {
  870. for (int64_t hi = 0; hi < h; hi++) {
  871. for (int64_t wi = 0; wi < w; wi++) {
  872. auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
  873. int64_t dst_idx = 0;
  874. if (args.device_format == kOpFormat_NHWC) {
  875. dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
  876. } else if (args.device_format == kOpFormat_HWCN) {
  877. dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
  878. }
  879. SetData(size, false, src_idx, dst_idx, args, result);
  880. }
  881. }
  882. }
  883. }
  884. return true;
  885. }
  886. bool FormatTransfer::TO_NCHW(const FormatArgs &args, void *result) {
  887. MS_LOG(DEBUG) << "Trans format to nchw from " << args.device_format;
  888. MS_EXCEPTION_IF_NULL(result);
  889. int64_t size = 0;
  890. if (!CheckArgs(args, &size)) {
  891. MS_LOG(ERROR) << "Check args failed.";
  892. return false;
  893. }
  894. auto n = args.host_shape[kN];
  895. auto c = args.host_shape[kC];
  896. auto h = args.host_shape[kH];
  897. auto w = args.host_shape[kW];
  898. for (int64_t ni = 0; ni < n; ni++) {
  899. for (int64_t ci = 0; ci < c; ci++) {
  900. for (int64_t hi = 0; hi < h; hi++) {
  901. for (int64_t wi = 0; wi < w; wi++) {
  902. auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
  903. int64_t src_idx = 0;
  904. if (args.device_format == kOpFormat_NHWC) {
  905. src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
  906. } else if (args.device_format == kOpFormat_HWCN) {
  907. src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
  908. }
  909. SetData(size, false, src_idx, dst_idx, args, result);
  910. }
  911. }
  912. }
  913. }
  914. return true;
  915. }
bool FormatTransfer::NCHW_TO_FRAC_Z(const FormatArgs &args, void *result) {
  // Packs NCHW host data into FRAC_Z: a (C1*H*W) x ceil(N/kNiSize) grid of
  // fractal tiles, each tile kNiSize x c0 elements; positions beyond the real
  // N/C extents are zero-padded.
  MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  auto c0 = GetCubeSizeByType(args.src_data_type);  // cube width is dtype-dependent
  auto c1 = DivCeil(c, c0);
  auto hw = h * w;
  auto chw = c * hw;
  auto hwc0 = hw * c0;
  auto nchw = n * chw;
  auto hf_cnt = DivCeil(n, kNiSize);    // horizontal fractal count (N direction)
  auto vf_cnt = c1 * hw;                // vertical fractal count (C1*H*W direction)
  auto fractal_ele_cnt = c0 * kNiSize;  // elements per fractal tile
  auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
  auto dst_size = total_ele_cnt * size;
  if (dst_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size."
                  << "dst size is :" << dst_size << "device size is :" << args.device_size;
    return false;
  }
  for (int64_t vfi = 0; vfi < vf_cnt; vfi++) {
    auto vf_base_i = vfi * hf_cnt;  // vertical fractal matrix base index
    for (int64_t hfi = 0; hfi < hf_cnt; hfi++) {
      auto gfi = vf_base_i + hfi;  // global fractal matrix index
      auto src_n_offset = hfi * chw * kNiSize;
      auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
      for (int64_t row = 0; row < c0; row++) {
        auto src_ci = vfi / hw * c0 + row;  // source channel of this tile row
        auto src_row_offset = src_f_offset + row * hw;
        for (int64_t col = 0; col < kNiSize; col++) {
          auto src_ni = hfi * kNiSize + col;  // source batch of this tile column
          auto src_idx = src_row_offset + chw * col;
          auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
          // Pad with zeros past the real N/C extents (and past the source buffer).
          auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
bool FormatTransfer::NCHW_TO_FRAC_NZ(const FormatArgs &args, void *result) {
  // Packs host data into FRACTAL_NZ: the host shape is first collapsed to
  // [times, H, W] (TransShapeToHW_NZ), then each HxW slice is scattered into
  // the [w1, h1, h0, w0] tiling of the device shape.
  MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  MS_EXCEPTION_IF_NULL(result);
  ShapeVector hw_shape;
  if (!TransShapeToHW_NZ(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  if (dst_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);  // product of all leading (batch) dims
  auto h = hw_shape.at(1);
  auto w = hw_shape.at(2);
  auto hw = h * w;
  // Trailing four device dims are the tiling factors [w1, h1, h0, w0].
  auto shape_size = args.device_shape.size();
  auto w1 = args.device_shape[shape_size - 4];
  auto h1 = args.device_shape[shape_size - 3];
  auto h0 = args.device_shape[shape_size - 2];
  auto w0 = args.device_shape[shape_size - 1];
  auto h1h0w0 = h1 * h0 * w0;
  auto w1h1h0w0 = w1 * h1h0w0;
  auto num_w1 = w / w0;  // number of complete w0-wide column blocks
  for (int64_t times_idx = 0; times_idx < times; times_idx++) {
    auto times_head = times_idx * w1h1h0w0;
    auto src_times_head = times_idx * hw;
    for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
      auto h1h0_head = times_head + h1h0_idx * w0;
      auto src_h_head = src_times_head + h1h0_idx * w;
      // Copy the complete column blocks of this row.
      for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
        for (int64_t i = 0; i < w0; ++i) {
          int64_t src_idx = src_h_head + w1_idx * w0 + i;
          int64_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
      // Copy the tail columns of the last, partially filled block.
      auto w1_head = num_w1 * w0;
      for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
        auto src_w_idx = w1_head + w0_idx;
        int64_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
        int64_t src_idx = src_h_head + src_w_idx;
        SetData(size, false, src_idx, dst_idx, args, result);
      }
    }
  }
  return true;
}
bool FormatTransfer::NCHW_TO_FRAC_ZC04(const FormatArgs &args, void *result) {
  // trans nchw to FracZc04
  // Packs NCHW host data into FRACTAL_Z_C04 (c0 fixed at 4): the device buffer
  // is written sequentially tile by tile (kNiSize x kNiSize), and each element
  // is mapped back to its NCHW source position; out-of-range positions pad zero.
  MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
  MS_EXCEPTION_IF_NULL(result);
  int64_t size = 0;
  if (!CheckArgs(args, &size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto cube = GetCubeSizeByType(args.src_data_type);
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  const int64_t c0 = 4;  // C04 packing always uses 4 channels per block
  auto c1 = DivCeil(c, c0);
  auto hwc0 = h * w * c0;
  auto hwc = h * w * c;
  auto nhwc = n * h * w * c;
  auto n_cnt = DivCeil(n, kNiSize);            // tile columns (N direction)
  auto v_cnt = DivCeil(h * w * c0 * c1, cube);  // tile rows (C1*H*W*c0 direction)
  int64_t dst_idx = 0;  // destination advances linearly over the device buffer
  for (int64_t vi = 0; vi < v_cnt; vi++) {
    for (int64_t ni = 0; ni < n_cnt; ni++) {
      for (int64_t col = 0; col < kNiSize; col++) {
        for (int64_t row = 0; row < kNiSize; row++) {
          int64_t cur_cube_n = kNiSize * ni + col;
          int64_t cur_cube_c1hwc0 = kNiSize * vi + row;
          // Decompose the flat tile coordinates back into (g, n, c1, c0, h, w).
          auto desc_g = cur_cube_n / n;
          auto desc_n = cur_cube_n % n;
          auto desc_c1 = cur_cube_c1hwc0 / hwc0;
          auto desc_c0 = cur_cube_c1hwc0 % c0;
          auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
          auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
          auto c_idx = desc_c1 * c0 + desc_c0;
          auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
          // Pad the slots that fall outside the real N/C extents.
          auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
          dst_idx++;
        }
      }
    }
  }
  return true;
}
bool FormatTransfer::NCHW_TO_NC1HWC0(const FormatArgs &args, void *result) {
  // Packs NCHW host data into NC1HWC0: channels are split into C1 groups of c0
  // and moved innermost ([N, C1, H, W, C0]); channels past C are zero-padded.
  // Also serves NC1HWC0_C04 (c0 forced to kCubeSize_C04) via NCHW_TO_NC1HWC04.
  MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  auto c0 = GetCubeSizeByType(args.src_data_type);
  if (args.device_format == kOpFormat_NC1HWC0_C04) {
    c0 = kCubeSize_C04;  // C04 variant packs 4 channels per block
  }
  auto c1 = DivCeil(c, c0);
  auto hw = h * w;
  auto chw = c * hw;
  auto c1hwc0 = c1 * hw * c0;
  auto wc0 = w * c0;
  // Walk the destination layout [N, C1, H, W, C0] and gather from NCHW.
  for (int64_t n_idx = 0; n_idx < n; n_idx++) {
    int64_t n_head_addr = n_idx * c1hwc0;
    for (int64_t c1_idx = 0; c1_idx < c1; c1_idx++) {
      int64_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
      for (int64_t h_idx = 0; h_idx < h; h_idx++) {
        int64_t h_head_addr = c1_head_addr + h_idx * wc0;
        for (int64_t w_idx = 0; w_idx < w; w_idx++) {
          int64_t w_head_addr = h_head_addr + w_idx * c0;
          for (int64_t c0_idx = 0; c0_idx < c0; c0_idx++) {
            int64_t dst_idx = c0_idx + w_head_addr;
            int64_t c_idx = c0_idx + c1_idx * c0;  // original channel index
            int64_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
            auto pad_zero = c_idx >= c;  // zero-fill the padded channel tail
            SetData(size, pad_zero, src_idx, dst_idx, args, result);
          }
        }
      }
    }
  }
  return true;
}
// NC1HWC0_C04 uses the same packing as NC1HWC0; NCHW_TO_NC1HWC0 inspects
// args.device_format and switches c0 to kCubeSize_C04 for this format.
bool FormatTransfer::NCHW_TO_NC1HWC04(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
  return NCHW_TO_NC1HWC0(args, result);
}
bool FormatTransfer::NCHW_TO_C1HWNCOC0(const FormatArgs &args, void *result) {
  // trans nchw to c1hwncoc0
  // Device layout [C1, H, W, N, Co, C0] taken from args.device_shape. Each
  // (co, c0) pair holds source channel c1*C0 + c0 only on the diagonal
  // (c0 == co); off-diagonal and out-of-range slots are zero-padded.
  MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
  MS_EXCEPTION_IF_NULL(result);
  int64_t size = 0;
  if (!CheckArgs(args, &size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  // Positions of Co and C0 within the 6-D device shape.
  const int co_idx = 4;
  const int c0_idx = 5;
  auto c1 = args.device_shape[0];
  auto co = args.device_shape[co_idx];
  auto c0 = args.device_shape[c0_idx];
  for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
    for (int64_t h_i = 0; h_i < h; h_i++) {
      for (int64_t w_i = 0; w_i < w; w_i++) {
        for (int64_t n_i = 0; n_i < n; n_i++) {
          for (int64_t co_i = 0; co_i < co; co_i++) {
            for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
              int64_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
                                co_i * c0 + c0_i;
              int64_t c_i = c0_i + c1_i * c0;  // original channel index
              int64_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
              // Only diagonal (c0_i == co_i) slots with a valid channel carry data.
              auto pad_zero = !(c_i < c && c0_i == co_i);
              SetData(size, pad_zero, src_idx, dst_idx, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
bool FormatTransfer::NCDHW_TO_NDC1HWC0(const FormatArgs &args, void *result) {
  // Packs NCDHW host data into NDC1HWC0: channels are split into C1 groups of
  // c0 and moved innermost ([N, D, C1, H, W, C0]); channels past C pad zero.
  MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  auto c0 = GetCubeSizeByType(args.src_data_type);
  auto c1 = DivCeil(c, c0);
  // Source (NCDHW) and destination (NDC1HWC0) strides.
  const int64_t cdhw = c * d * h * w;
  const int64_t dhw = d * h * w;
  const int64_t hw = h * w;
  const int64_t dc1hwc0 = d * c1 * h * w * c0;
  const int64_t c1hwc0 = c1 * h * w * c0;
  const int64_t hwc0 = h * w * c0;
  const int64_t wc0 = w * c0;
  // Walk the destination layout and gather from the NCDHW source.
  for (int64_t n_i = 0; n_i < n; n_i++) {
    int64_t n_head = n_i * dc1hwc0;
    for (int64_t d_i = 0; d_i < d; d_i++) {
      int64_t d_head = n_head + d_i * c1hwc0;
      for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
        int64_t c1_head = d_head + c1_i * hwc0;
        for (int64_t h_i = 0; h_i < h; h_i++) {
          int64_t h_head = c1_head + h_i * wc0;
          for (int64_t w_i = 0; w_i < w; w_i++) {
            int64_t w_head = h_head + w_i * c0;
            for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
              int64_t dst_i = c0_i + w_head;
              int64_t c_i = c0_i + c1_i * c0;  // original channel index
              int64_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
              auto pad_zero = c_i >= c;  // zero-fill the padded channel tail
              SetData(size, pad_zero, src_i, dst_i, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
bool FormatTransfer::NCDHW_TO_FRAC_Z3D(const FormatArgs &args, void *result) {
  // Packs NCDHW host data into FRACTAL_Z_3D, layout [D, C1, H, W, N1*N0, C0]
  // with N padded to a multiple of kNiSize and C split into C1 groups of c0;
  // padded N/C slots are zero-filled.
  MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  auto n1n0 = DivCeil(n, kNiSize) * kNiSize;  // N padded up to a tile multiple
  auto c0 = GetCubeSizeByType(args.src_data_type);
  auto c1 = DivCeil(c, c0);
  // Source (NCDHW) and destination strides.
  auto hw = h * w;
  auto dhw = d * hw;
  auto cdhw = c * dhw;
  auto n1n0c0 = n1n0 * c0;
  auto wn1n0c0 = w * n1n0c0;
  auto hwn1n0c0 = h * wn1n0c0;
  auto c1hwn1n0c0 = c1 * hwn1n0c0;
  // Walk the destination layout and gather from the NCDHW source.
  for (int64_t d_i = 0; d_i < d; d_i++) {
    for (int64_t c1_i = 0; c1_i < c1; c1_i++) {
      for (int64_t h_i = 0; h_i < h; h_i++) {
        for (int64_t w_i = 0; w_i < w; w_i++) {
          for (int64_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
            for (int64_t c0_i = 0; c0_i < c0; c0_i++) {
              auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
              // ncdhw
              int64_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
              // Zero-fill slots past the real C or N extents.
              auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
              SetData(size, pad_zero, src_i, dst_i, args, result);
            }
          }
        }
      }
    }
  }
  return true;
}
// Transforms grouped-convolution weights between NCHW and FRAC_Z. Groups are
// packed e_mult at a time so that per-packed-group Cin/Cout align to the cube
// size; the whole destination is zeroed first so padded slots stay zero.
// to_device selects the direction (true: host->device, false: device->host).
// NOTE(review): "GROPUS" is a typo in the public name; kept — callers use it.
bool FormatTransfer::NCHW_TO_FRAC_Z_WITH_GROPUS(const FormatArgs &args, void *result, bool to_device, int64_t groups) {
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto n_dim = args.host_shape[kN];
  auto c_dim = args.host_shape[kC];
  auto h_dim = args.host_shape[kH];
  auto w_dim = args.host_shape[kW];
  auto d_dim = 1;  // 4-D weights: depth is a degenerate axis of 1
  auto cin_ori = c_dim;
  if (groups <= 0) {
    MS_LOG(EXCEPTION) << "The value of groups should be greater than 0, but got " << groups;
  }
  // cppcheck-suppress *
  auto cout_ori = n_dim / groups;  // output channels per group
  if (cin_ori == 0 || cout_ori == 0) {
    MS_LOG(ERROR) << "cin_ori, cout_ori must not equal to 0";
    return false;
  }
  auto cube_k = GetCubeSizeByType(args.src_data_type);
  // Smallest multiplier aligning both Cin and Cout to cube size, capped at groups.
  auto e_mult = std::min(Lcm(Lcm(cin_ori, cube_k) / cin_ori, Lcm(cout_ori, kCubeSize) / cout_ori), groups);
  if (e_mult == 0) {
    MS_LOG(EXCEPTION) << "The value of e_mult should be greater than 0, but got " << e_mult;
  }
  auto cin_opt = DivCeil(e_mult * cin_ori, cube_k) * cube_k;         // padded Cin per packed group
  auto cout_opt = DivCeil(e_mult * cout_ori, kCubeSize) * kCubeSize; // padded Cout per packed group
  // cppcheck-suppress *
  auto c1_dim = cin_opt / cube_k;
  auto dst_size =
    to_device ? abstract::ShapeSize(args.device_shape) * size : abstract::ShapeSize(args.host_shape) * size;
  if (dst_size == 0) {
    return true;
  }
  // Zero the whole destination so padded slots need no explicit writes.
  auto ret = memset_s(result, dst_size, 0, dst_size);
  if (ret != EOK) {
    MS_LOG(ERROR) << "memset failed";
    return false;
  }
  for (int64_t g = 0; g < groups; ++g) {
    for (int64_t d = 0; d < d_dim; ++d) {
      for (int64_t c = 0; c < c_dim; ++c) {
        for (int64_t h = 0; h < h_dim; ++h) {
          for (int64_t w = 0; w < w_dim; ++w) {
            for (int64_t n = 0; n < cout_ori; ++n) {
              int64_t e_val = g % e_mult;             // slot of this group inside its packed group
              int64_t dst_ci = e_val * cin_ori + c;   // packed input-channel index
              int64_t dst_co = e_val * cout_ori + n;  // packed output-channel index
              int64_t src_co = g * cout_ori + n;
              int64_t temporary = dst_ci % cube_k;    // position inside the cube_k block
              // Device index over [G/e_mult, D, C1, H, W, Cout_opt, cube_k].
              int64_t dev_idx = (g / e_mult) * d_dim * c1_dim * h_dim * w_dim * cout_opt * cube_k +
                                d * c1_dim * h_dim * w_dim * cout_opt * cube_k +
                                (dst_ci / cube_k) * h_dim * w_dim * cout_opt * cube_k + h * w_dim * cout_opt * cube_k +
                                w * cout_opt * cube_k + dst_co * cube_k + temporary;
              // Host index over [N, C, D, H, W].
              int64_t hst_idx =
                src_co * c_dim * d_dim * h_dim * w_dim + c * d_dim * h_dim * w_dim + d * h_dim * w_dim + h * w_dim + w;
              if (to_device) {
                SetData(size, false, hst_idx, dev_idx, args, result);
              } else {
                SetData(size, false, dev_idx, hst_idx, args, result);
              }
            }
          }
        }
      }
    }
  }
  return true;
}
// Restore host NCHW data from device NC1HWC0 layout, where the channel axis
// is tiled into C1 blocks of C0 channels (device shape: N, C1, H, W, C0).
// Returns false on invalid host shape, dtype, or device-size mismatch.
bool FormatTransfer::NC1HWC0_TO_NCHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  // The device buffer must hold exactly ShapeSize(device_shape) elements.
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  // Device dims: index 1 is C1 (channel-block count), index 4 is C0 (block width).
  auto c1 = args.device_shape[1];
  auto c0 = args.device_shape[4];
  // Precomputed strides: hw/chw for the host layout, wc0/hwc0/c1hwc0 for device.
  auto hw = h * w;
  auto chw = c * hw;
  auto wc0 = w * c0;
  auto hwc0 = h * wc0;
  auto c1hwc0 = c1 * hwc0;
  for (int64_t n_idx = 0; n_idx < n; n_idx++) {
    int64_t n_head_addr = n_idx * chw;
    for (int64_t c_idx = 0; c_idx < c; c_idx++) {
      int64_t c_head_addr = n_head_addr + c_idx * hw;
      for (int64_t h_idx = 0; h_idx < h; h_idx++) {
        int64_t h_head_addr = c_head_addr + h_idx * w;
        for (int64_t w_idx = 0; w_idx < w; w_idx++) {
          int64_t dst_idx = h_head_addr + w_idx;
          // Split the host channel into (block index, offset within block).
          int64_t c1_idx = c_idx / c0;
          int64_t c0_idx = c_idx % c0;
          int64_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
  1392. bool FormatTransfer::NC1HWC04_TO_NCHW(const FormatArgs &args, void *result) {
  1393. MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
  1394. return NC1HWC0_TO_NCHW(args, result);
  1395. }
// Restore host NCHW data from device C1HWNCoC0 layout (device shape:
// C1, H, W, N, Co, C0). Returns false when CheckArgs rejects the arguments.
bool FormatTransfer::C1HWNCOC0_TO_NCHW(const FormatArgs &args, void *result) {
  // trans c1hwncoc0 to nchw
  MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
  MS_EXCEPTION_IF_NULL(result);
  int64_t size = 0;
  if (!CheckArgs(args, &size)) {
    MS_LOG(ERROR) << "Check args failed.";
    return false;
  }
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  // Device dims: index 4 is Co, index 5 is C0.
  const int co_idx = 4;
  const int c0_idx = 5;
  auto co = args.device_shape[co_idx];
  auto c0 = args.device_shape[c0_idx];
  // Channel blocking factor follows the data type's cube size.
  auto cube_k = GetCubeSizeByType(args.src_data_type);
  for (int64_t n_i = 0; n_i < n; n_i++) {
    for (int64_t c_i = 0; c_i < c; c_i++) {
      for (int64_t h_i = 0; h_i < h; h_i++) {
        for (int64_t w_i = 0; w_i < w; w_i++) {
          int64_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
          int64_t c1_i = c_i / cube_k;
          int64_t c0_i = c_i % cube_k;
          // co_i == c0_i: valid data lies on the Co == C0 diagonal of the
          // device layout (presumably; the rest of the cube is padding).
          int64_t co_i = c0_i;
          int64_t src_idx =
            c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
// Restore host NCHW data from device FRACTAL_Z layout (flattened traversal
// order C1, H, W, N1*Ni, C0). Returns false on invalid shape/dtype/size args.
bool FormatTransfer::FRAC_Z_TO_NCHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  // Device dims: 1 -> N0 (N block count), 2 -> Ni (block width), 3 -> C0.
  auto n0 = args.device_shape.at(1);
  auto ni = args.device_shape.at(2);
  auto c0 = args.device_shape.at(3);
  auto n = args.host_shape[kN];
  auto c = args.host_shape[kC];
  auto h = args.host_shape[kH];
  auto w = args.host_shape[kW];
  // Device strides over the (C1, H, W, N, C0) traversal.
  auto nc = ni * n0;
  auto ncc0 = nc * c0;
  auto wncc0 = w * ncc0;
  auto hwncc0 = h * wncc0;
  // Host strides.
  auto hw = h * w;
  auto chw = c * hw;
  for (int64_t n_idx = 0; n_idx < n; n_idx++) {
    int64_t n_head_addr = n_idx * chw;
    for (int64_t c_idx = 0; c_idx < c; c_idx++) {
      int64_t c_head_addr = n_head_addr + c_idx * hw;
      for (int64_t h_idx = 0; h_idx < h; h_idx++) {
        int64_t h_head_addr = c_head_addr + h_idx * w;
        for (int64_t w_idx = 0; w_idx < w; w_idx++) {
          auto dst_idx = h_head_addr + w_idx;
          // Channel split into (block, offset); N is addressed directly.
          auto c1_idx = c_idx / c0;
          auto c0_idx = c_idx % c0;
          auto nc_idx = n_idx;
          auto src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
    }
  }
  return true;
}
// Restore host data from device FRAC_NZ layout. The host shape is first
// collapsed to (times, H, W); the device's trailing dims are (W1, H1, H0, W0).
// Returns false on invalid shape/dtype/size args.
bool FormatTransfer::FRAC_NZ_TO_NCHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
  MS_EXCEPTION_IF_NULL(result);
  ShapeVector hw_shape;
  if (!TransShapeToHW_NZ(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < kDim3 || args.device_shape.size() < kDim4) {
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  if (dst_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  // Collapsed host view: 'times' batches of an h x w matrix.
  auto times = hw_shape.at(0);
  auto h = hw_shape.at(1);
  auto w = hw_shape.at(2);
  auto hw = h * w;
  // Device trailing dims: W1 (column-block count), H1, H0, W0 (block sizes).
  auto shape_size = args.device_shape.size();
  auto w1 = args.device_shape[shape_size - 4];
  auto h1 = args.device_shape[shape_size - 3];
  auto h0 = args.device_shape[shape_size - 2];
  auto w0 = args.device_shape[shape_size - 1];
  auto h1h0w0 = h1 * h0 * w0;
  auto w1h1h0w0 = w1 * h1h0w0;
  // Number of complete w0-wide column blocks in the host width.
  auto num_w1 = w / w0;
  for (int64_t times_idx = 0; times_idx < times; times_idx++) {
    auto times_head = times_idx * w1h1h0w0;
    auto src_times_head = times_idx * hw;
    for (int64_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
      auto h1h0_head = times_head + h1h0_idx * w0;
      auto src_h_head = src_times_head + h1h0_idx * w;
      // Copy the complete column blocks.
      for (int64_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
        for (int64_t i = 0; i < w0; ++i) {
          int64_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
          int64_t dst_idx = src_h_head + w1_idx * w0 + i;
          SetData(size, false, src_idx, dst_idx, args, result);
        }
      }
      // Copy the tail columns when w is not a multiple of w0.
      auto w1_head = num_w1 * w0;
      for (int64_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
        auto src_w_idx = w1_head + w0_idx;
        int64_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
        int64_t dst_idx = src_h_head + src_w_idx;
        SetData(size, false, src_idx, dst_idx, args, result);
      }
    }
  }
  return true;
}
// Restore host NCDHW data from device FRACTAL_Z_3D layout (flattened
// traversal order D, C1, H, W, N1*N0, C0). Returns false on invalid args.
bool FormatTransfer::FRAC_Z3D_TO_NCDHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  // Device dim 3 is C0.
  const int kFZ3D_C0 = 3;
  auto c0 = args.device_shape[kFZ3D_C0];
  auto cube_k = GetCubeSizeByType(args.src_data_type);
  // NOTE(review): c1 is derived from cube_k, but the inner loop splits the
  // channel by c0 (device dim 3); this assumes c0 == cube_k — verify.
  auto c1 = DivCeil(c, cube_k);
  // N padded up to a multiple of the Ni block (N1 * N0).
  auto n1n0 = DivCeil(n, kNiSize) * kNiSize;
  auto n1n0c0 = n1n0 * c0;
  auto wn1n0c0 = w * n1n0c0;
  auto hwn1n0c0 = h * wn1n0c0;
  auto c1hwn1n0c0 = c1 * hwn1n0c0;
  // Host strides.
  auto hw = h * w;
  auto dhw = d * hw;
  auto cdhw = c * dhw;
  for (int64_t n_i = 0; n_i < n; n_i++) {
    int64_t n_head = n_i * cdhw;
    for (int64_t c_i = 0; c_i < c; c_i++) {
      int64_t c_head = n_head + c_i * dhw;
      for (int64_t d_i = 0; d_i < d; d_i++) {
        int64_t d_head = c_head + d_i * hw;
        for (int64_t h_i = 0; h_i < h; h_i++) {
          int64_t h_head = d_head + h_i * w;
          for (int64_t w_i = 0; w_i < w; w_i++) {
            int64_t dst_i = h_head + w_i;
            // Split channel into (block, offset) and index the device layout.
            int64_t c1_i = c_i / c0;
            int64_t c0_i = c_i % c0;
            int64_t nc_i = n_i;
            int64_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i;
            SetData(size, false, src_i, dst_i, args, result);
          }
        }
      }
    }
  }
  return true;
}
// Restore host NCDHW data from device NDC1HWC0 layout (device shape:
// N, D, C1, H, W, C0). Returns false on invalid shape/dtype/size args.
bool FormatTransfer::NDC1HWC0_TO_NCDHW(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNcdhw) {
    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
    return false;
  }
  auto size = SizeToLong(abstract::TypeIdSize(args.src_data_type));
  if (size < 1) {
    MS_LOG(ERROR) << "Illegal dtype: " << args.src_data_type;
    return false;
  }
  auto total_size = abstract::ShapeSize(args.device_shape) * size;
  if (total_size != SizeToLong(args.device_size)) {
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
  auto n = args.host_shape[N_ncdhw];
  auto c = args.host_shape[C_ncdhw];
  auto d = args.host_shape[D_ncdhw];
  auto h = args.host_shape[H_ncdhw];
  auto w = args.host_shape[W_ncdhw];
  auto c1 = args.device_shape[C1_ndc1hwc0];
  auto c0 = args.device_shape[C0_ndc1hwc0];
  // Host strides (NCDHW) and device strides (NDC1HWC0).
  const int64_t cdhw = c * d * h * w;
  const int64_t dhw = d * h * w;
  const int64_t hw = h * w;
  const int64_t dc1hwc0 = d * c1 * h * w * c0;
  const int64_t c1hwc0 = c1 * h * w * c0;
  const int64_t hwc0 = h * w * c0;
  const int64_t wc0 = w * c0;
  for (int64_t n_i = 0; n_i < n; n_i++) {
    int64_t n_head = n_i * cdhw;
    for (int64_t c_i = 0; c_i < c; c_i++) {
      int64_t c_head = n_head + c_i * dhw;
      for (int64_t d_i = 0; d_i < d; d_i++) {
        int64_t d_head = c_head + d_i * hw;
        for (int64_t h_i = 0; h_i < h; h_i++) {
          int64_t h_head = d_head + h_i * w;
          for (int64_t w_i = 0; w_i < w; w_i++) {
            int64_t dst_i = h_head + w_i;
            // Split the host channel into (block index, offset within block).
            int64_t c1_i = c_i / c0;
            int64_t c0_i = c_i % c0;
            auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
            SetData(size, false, src_idx, dst_i, args, result);
          }
        }
      }
    }
  }
  return true;
}
// Restore host NCHW data from grouped FRACTAL_Z by running the grouped
// NCHW -> FRAC_Z transfer in reverse (to_device = false).
bool FormatTransfer::FRAC_Z_TO_NCHW_WITH_GROUPS(const FormatArgs &args, void *result, int64_t groups) {
  MS_LOG(DEBUG) << "Trans format from frac_z to nchw with groups=" << groups;
  // NOTE: "GROPUS" is a pre-existing typo in the callee's name, kept as-is.
  return NCHW_TO_FRAC_Z_WITH_GROPUS(args, result, false, groups);
}
  1650. // ######################## RANGE TRANS ########################
  1651. RangePair ShapeRangeTransfer::GetRealRange(const RangePair &ori_range, const std::string &format, const TypeId &type) {
  1652. const std::set<std::string> no_need_change = {kOpFormat_ND, kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NCDHW};
  1653. using RangeTransfer = std::function<RangePair(const RangePair &, const TypeId &)>;
  1654. const std::map<std::string, RangeTransfer> format_range_map = {{kOpFormat_NHWC, NHWCRange},
  1655. {kOpFormat_HWCN, HWCNRange},
  1656. {kOpFormat_FRAC_Z, FRAC_ZRange},
  1657. {kOpFormat_NC1HWC0, NC1HWC0Range},
  1658. {kOpFormat_NDC1HWC0, NDC1HWC0Range},
  1659. {kOpFormat_C1HWNCoC0, C1HWNCOC0Range},
  1660. {kOpFormat_NC1HWC0_C04, NC1HWC04Range},
  1661. {kOpFormat_FRACTAL_Z_3D, FRAC_Z_3DRange},
  1662. {kOpFormat_FRACTAL_Z_C04, FRAC_ZC04Range}};
  1663. if (no_need_change.find(format) != no_need_change.end()) {
  1664. return ori_range;
  1665. }
  1666. // kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_FRAC_NZ no need pad range
  1667. if (format == kOpFormat_FRACTAL_ZN_LSTM) {
  1668. return FRAC_ZN_LSTMRange(ori_range, type);
  1669. }
  1670. if (format == kOpFormat_FRAC_NZ) {
  1671. return FRAC_NZRange(ori_range, type);
  1672. }
  1673. auto temp_range = ori_range;
  1674. if (ori_range.size() < kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
  1675. MS_LOG(DEBUG) << "A special format:" << format << " with a range size less than 4, so padding the range firstly";
  1676. temp_range = PaddingRangeTo4D(ori_range);
  1677. }
  1678. if (ori_range.size() < kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) {
  1679. MS_LOG(DEBUG) << "A special format:" << format << " with a range size less than 5, so padding the range firstly";
  1680. temp_range = PaddingRangeTo5D(ori_range);
  1681. }
  1682. auto iter = format_range_map.find(format);
  1683. if (iter == format_range_map.end()) {
  1684. MS_LOG(INFO) << "Can not find a supported format: " << format << ", using default range";
  1685. return ori_range;
  1686. }
  1687. return iter->second(temp_range, type);
  1688. }
  1689. RangePair ShapeRangeTransfer::NHWCRange(const RangePair &ori_range, const TypeId &) {
  1690. RangePair dst_range;
  1691. dst_range.push_back(ori_range[kN]);
  1692. dst_range.push_back(ori_range[kH]);
  1693. dst_range.push_back(ori_range[kW]);
  1694. dst_range.push_back(ori_range[kC]);
  1695. return dst_range;
  1696. }
  1697. RangePair ShapeRangeTransfer::HWCNRange(const RangePair &ori_range, const TypeId &) {
  1698. RangePair dst_range;
  1699. dst_range.push_back(ori_range[kH]);
  1700. dst_range.push_back(ori_range[kW]);
  1701. dst_range.push_back(ori_range[kC]);
  1702. dst_range.push_back(ori_range[kN]);
  1703. return dst_range;
  1704. }
  1705. RangePair ShapeRangeTransfer::NC1HWC04Range(const RangePair &ori_range, const TypeId &) {
  1706. RangePair dst_range;
  1707. const std::pair<int64_t, int64_t> c0 = {k4, k4};
  1708. const std::pair<int64_t, int64_t> c1 = {(ori_range[kC].first + k4 - 1) / k4, (ori_range[kC].second + k4 - 1) / k4};
  1709. dst_range.push_back(ori_range[kN]);
  1710. dst_range.push_back(c1);
  1711. dst_range.push_back(ori_range[kH]);
  1712. dst_range.push_back(ori_range[kW]);
  1713. dst_range.push_back(c0);
  1714. return dst_range;
  1715. }
  1716. RangePair ShapeRangeTransfer::FRAC_ZC04Range(const RangePair &ori_range, const TypeId &) {
  1717. RangePair dst_range;
  1718. const std::pair<int64_t, int64_t> c0 = {k4, k4};
  1719. const std::pair<int64_t, int64_t> c16 = {kNiSize, kNiSize};
  1720. const std::pair<int64_t, int64_t> first_dim = {
  1721. (c0.first * ori_range[kH].first * ori_range[kW].first + kNiSize - 1) / kNiSize,
  1722. (c0.second * ori_range[kH].second * ori_range[kW].second + kNiSize - 1) / kNiSize};
  1723. const std::pair<int64_t, int64_t> no = {(ori_range[kN].first + kNiSize - 1) / kNiSize,
  1724. (ori_range[kN].second + kNiSize - 1) / kNiSize};
  1725. dst_range.push_back(first_dim);
  1726. dst_range.push_back(no);
  1727. dst_range.push_back(c16);
  1728. dst_range.push_back(c16);
  1729. return dst_range;
  1730. }
  1731. RangePair ShapeRangeTransfer::FRAC_ZRange(const RangePair &ori_range, const TypeId &type) {
  1732. RangePair dst_range;
  1733. auto cube = GetCubeSizeByType(type);
  1734. const std::pair<int64_t, int64_t> c0 = {cube, cube};
  1735. const std::pair<int64_t, int64_t> cout16 = {((ori_range[kN].first + kNiSize - 1) / kNiSize) * kNiSize,
  1736. ((ori_range[kN].second + kNiSize - 1) / kNiSize) * kNiSize};
  1737. const std::pair<int64_t, int64_t> cin16 = {((ori_range[kC].first + cube - 1) / cube) * cube,
  1738. ((ori_range[kC].second + cube - 1) / cube) * cube};
  1739. const std::pair<int64_t, int64_t> r0 = {ori_range[kH].first * ori_range[kW].first * cin16.first / cube,
  1740. ori_range[kH].second * ori_range[kW].second * cin16.second / cube};
  1741. const std::pair<int64_t, int64_t> r1 = {cout16.first / kNiSize, cout16.second / kNiSize};
  1742. dst_range.push_back(r0);
  1743. dst_range.push_back(r1);
  1744. dst_range.push_back({kNiSize, kNiSize});
  1745. dst_range.push_back(c0);
  1746. return dst_range;
  1747. }
  1748. RangePair ShapeRangeTransfer::FRAC_NZRange(const RangePair &ori_range, const TypeId &type) {
  1749. RangePair dst_range;
  1750. auto cube = GetCubeSizeByType(type);
  1751. auto ori_size = ori_range.size();
  1752. if (ori_size < kDims2) {
  1753. MS_LOG(EXCEPTION) << "Format FracNZ can not support range size: " << ori_size;
  1754. } else {
  1755. (void)std::copy(ori_range.begin(), ori_range.end() - kDims2, std::back_inserter(dst_range));
  1756. }
  1757. const std::pair<int64_t, int64_t> c0 = {cube, cube};
  1758. const std::pair<int64_t, int64_t> w1 = {(ori_range[ori_size - 1].first - 1) / cube + 1,
  1759. (ori_range[ori_size - 1].second - 1) / cube + 1};
  1760. const std::pair<int64_t, int64_t> h1 = {(ori_range[ori_size - kDims2].first - 1) / kNiSize + 1,
  1761. (ori_range[ori_size - kDims2].second - 1) / kNiSize + 1};
  1762. dst_range.push_back(w1);
  1763. dst_range.push_back(h1);
  1764. dst_range.push_back({kNiSize, kNiSize});
  1765. dst_range.push_back(c0);
  1766. return dst_range;
  1767. }
  1768. RangePair ShapeRangeTransfer::NC1HWC0Range(const RangePair &ori_range, const TypeId &type) {
  1769. RangePair dst_range;
  1770. auto cube = GetCubeSizeByType(type);
  1771. const std::pair<int64_t, int64_t> c0 = {cube, cube};
  1772. const std::pair<int64_t, int64_t> c1 = {(ori_range[kC].first + cube - 1) / cube,
  1773. (ori_range[kC].second + cube - 1) / cube};
  1774. dst_range.push_back(ori_range[kN]);
  1775. dst_range.push_back(c1);
  1776. dst_range.push_back(ori_range[kH]);
  1777. dst_range.push_back(ori_range[kW]);
  1778. dst_range.push_back(c0);
  1779. return dst_range;
  1780. }
  1781. RangePair ShapeRangeTransfer::FRAC_ZN_LSTMRange(const RangePair &ori_range, const TypeId &) {
  1782. RangePair dst_range;
  1783. const std::pair<int64_t, int64_t> c0 = {k4, k4};
  1784. const std::pair<int64_t, int64_t> c16 = {k4, k4};
  1785. const std::pair<int64_t, int64_t> h = {ori_range[kN].first / c0.first, ori_range[kN].second / c0.second};
  1786. const std::pair<int64_t, int64_t> i = {ori_range[kC].first - h.first, ori_range[kC].second - h.second};
  1787. const std::pair<int64_t, int64_t> first_dim = {
  1788. (i.first + kCube16 - 1) / kCube16 + (h.first + kCube16 - 1) / kCube16,
  1789. (i.second + kCube16 - 1) / kCube16 + (h.second + kCube16 - 1) / kCube16};
  1790. const std::pair<int64_t, int64_t> second = {c0.first * ((h.first + kCube16 - 1) / kCube16),
  1791. c0.second * ((h.second + kCube16 - 1) / kCube16)};
  1792. dst_range.push_back(first_dim);
  1793. dst_range.push_back(second);
  1794. dst_range.push_back(c16);
  1795. dst_range.push_back(c16);
  1796. return dst_range;
  1797. }
  1798. RangePair ShapeRangeTransfer::NDC1HWC0Range(const RangePair &ori_range, const TypeId &type) {
  1799. RangePair dst_range;
  1800. auto cube = GetCubeSizeByType(type);
  1801. const std::pair<int64_t, int64_t> c0 = {cube, cube};
  1802. const std::pair<int64_t, int64_t> c1 = {(ori_range[C_ncdhw].first + cube - 1) / cube,
  1803. (ori_range[C_ncdhw].second + cube - 1) / cube};
  1804. dst_range.push_back(ori_range[N_ncdhw]);
  1805. dst_range.push_back(ori_range[D_ncdhw]);
  1806. dst_range.push_back(c1);
  1807. dst_range.push_back(ori_range[H_ncdhw]);
  1808. dst_range.push_back(ori_range[W_ncdhw]);
  1809. dst_range.push_back(c0);
  1810. return dst_range;
  1811. }
  1812. RangePair ShapeRangeTransfer::C1HWNCOC0Range(const RangePair &ori_range, const TypeId &type) {
  1813. RangePair dst_range;
  1814. auto cube = GetCubeSizeByType(type);
  1815. const std::pair<int64_t, int64_t> c0 = {cube, cube};
  1816. const std::pair<int64_t, int64_t> r1 = {(ori_range[kC].first - 1) / cube + 1, (ori_range[kC].second - 1) / cube + 1};
  1817. dst_range.push_back(r1);
  1818. dst_range.push_back(ori_range[kH]);
  1819. dst_range.push_back(ori_range[kW]);
  1820. dst_range.push_back(ori_range[kN]);
  1821. dst_range.push_back(c0);
  1822. dst_range.push_back(c0);
  1823. return dst_range;
  1824. }
// FRACTAL_Z_3D range, built from an NCDHW range.
RangePair ShapeRangeTransfer::FRAC_Z_3DRange(const RangePair &ori_range, const TypeId &type) {
  RangePair dst_range;
  auto cube = GetCubeSizeByType(type);
  const std::pair<int64_t, int64_t> c0 = {cube, cube};
  // C1 = ceil(C / cube), N1 = ceil(N / Ni), for both range endpoints.
  const std::pair<int64_t, int64_t> c1 = {(ori_range[C_ncdhw].first + cube - 1) / cube,
                                          (ori_range[C_ncdhw].second + cube - 1) / cube};
  const std::pair<int64_t, int64_t> n1 = {(ori_range[N_ncdhw].first + kNiSize - 1) / kNiSize,
                                          (ori_range[N_ncdhw].second + kNiSize - 1) / kNiSize};
  // First device dim packs D * C1 * H * W.
  const int64_t r1_0 = ori_range[D_ncdhw].first * c1.first * ori_range[H_ncdhw].first * ori_range[W_ncdhw].first;
  const int64_t r1_1 = ori_range[D_ncdhw].second * c1.second * ori_range[H_ncdhw].second * ori_range[W_ncdhw].second;
  const std::pair<int64_t, int64_t> r1 = {r1_0, r1_1};
  dst_range.push_back(r1);
  dst_range.push_back(n1);
  // NOTE(review): the third dim pushed here is c1, while FRACTAL_Z_3D device
  // shapes conventionally carry the Ni block (16) at this position — verify
  // against the corresponding shape-transfer routine.
  dst_range.push_back(c1);
  dst_range.push_back(c0);
  return dst_range;
}
  1842. } // namespace trans
  1843. } // namespace mindspore