You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trans.cc 49 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/trans.h"
  17. #include <functional>
  18. #include <numeric>
  19. #include <utility>
  20. #include "utils/ms_utils.h"
  21. #include "abstract/utils.h"
  22. #include "backend/session/anf_runtime_algorithm.h"
  23. #include "backend/kernel_compiler/kernel.h"
  24. #include "runtime/device/convert_tensor_utils.h"
  25. #include "utils/convert_utils.h"
  26. #include "utils/log_adapter.h"
  27. #include "utils/utils.h"
  28. namespace mindspore {
  29. namespace trans {
  // Index of each axis in a 4-D host shape, followed by two rank constants used by the checks in this file.
  enum kAxis : int {
    kN = 0,
    kC = 1,
    kH = 2,
    kW = 3,
    kNchwDims = 4,  // rank that CheckDims/CheckArgs require for a 4-D host shape
    kNdhwc = 5      // minimum rank accepted by NcdhwDeviceShape
  };
  31. inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
  32. switch (size) {
  33. case 1:
  34. static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
  35. break;
  36. case 2:
  37. static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
  38. break;
  39. case 4:
  40. static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
  41. break;
  42. case 8:
  43. static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
  44. break;
  45. default:
  46. MS_LOG(EXCEPTION) << "Trans data not support size " << size;
  47. }
  48. }
  // Ceiling division: returns n1 / n2 rounded up, computed as (n1 + n2 - 1) / n2.
  // A zero divisor yields 0 instead of dividing.
  template <typename T>
  T DivCeil(T n1, T n2) {
    if (n2 == 0) {
      return 0;
    }
    const T rounded_up_numerator = n1 + n2 - 1;
    return rounded_up_numerator / n2;
  }
  // Identifies one supported element-type conversion (source type -> destination type).
  // mode_map produces these values from a TypeId pair and CastKernel dispatches on them.
  enum DataTypeTransMode {
    FROM_FLOAT_TO_FLOAT16 = 0,
    FROM_FLOAT_TO_INT32 = 1,
    FROM_FLOAT16_TO_FLOAT = 2,
    FROM_FLOAT16_TO_INT32 = 3,
    FROM_FLOAT16_TO_UINT8 = 4,
    FROM_INT32_TO_FLOAT = 5,
    FROM_INT32_TO_FLOAT16 = 6,
    FROM_INT32_TO_UINT8 = 7,
    FROM_INT32_TO_INT8 = 8,
    FROM_INT32_TO_INT64 = 9,
    FROM_INT32_TO_BOOL = 10,
    FROM_UINT8_TO_FLOAT = 11,
    FROM_UINT8_TO_INT32 = 12,
    FROM_UINT8_TO_FLOAT16 = 13,
    FROM_INT8_TO_FLOAT = 14,
    FROM_INT8_TO_FLOAT16 = 15,
    FROM_INT8_TO_INT32 = 16,
    FROM_INT64_TO_INT32 = 17,
    FROM_UINT16_TO_INT32 = 18,
    FROM_BOOL_TO_FLOAT = 19,
    FROM_BOOL_TO_INT32 = 20,
    FROM_BOOL_TO_UINT8 = 21,
    FROM_BOOL_TO_FLOAT16 = 22,
    FROM_FLOAT64_TO_FLOAT32 = 23,
    FROM_FLOAT32_TO_FLOAT64 = 24
  };
  83. const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
  84. {std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), FROM_FLOAT64_TO_FLOAT32},
  85. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), FROM_FLOAT32_TO_FLOAT64},
  86. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), FROM_FLOAT_TO_FLOAT16},
  87. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), FROM_FLOAT_TO_INT32},
  88. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), FROM_FLOAT16_TO_FLOAT},
  89. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), FROM_FLOAT16_TO_INT32},
  90. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), FROM_FLOAT16_TO_UINT8},
  91. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), FROM_INT32_TO_FLOAT},
  92. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), FROM_INT32_TO_FLOAT16},
  93. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), FROM_INT32_TO_UINT8},
  94. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), FROM_INT32_TO_INT8},
  95. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt64), FROM_INT32_TO_INT64},
  96. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), FROM_INT32_TO_BOOL},
  97. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), FROM_UINT8_TO_FLOAT},
  98. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32},
  99. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), FROM_UINT8_TO_FLOAT16},
  100. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT},
  101. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), FROM_INT8_TO_FLOAT16},
  102. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32},
  103. {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32},
  104. {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32},
  105. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), FROM_BOOL_TO_INT32},
  106. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), FROM_BOOL_TO_FLOAT},
  107. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), FROM_BOOL_TO_UINT8},
  108. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};
  109. void CheckMemSize(const TypeIdArgs &args) {
  110. auto src_type_size = abstract::TypeIdSize(args.host_data_type);
  111. auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
  112. if (src_type_size < 1 || dst_type_size < 1) {
  113. MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
  114. }
  115. if (args.data_size / src_type_size != args.host_shape_size) {
  116. MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
  117. }
  118. }
  119. template <typename SrcT, typename DstT>
  120. void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
  121. CheckMemSize(args);
  122. for (size_t idx = 0; idx != data_size; idx++) {
  123. SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
  124. static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
  125. }
  126. }
  127. template <typename SrcT>
  128. void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
  129. CheckMemSize(args);
  130. auto src_data = static_cast<const SrcT *>(args.data);
  131. auto half_data = static_cast<float16 *>(dst);
  132. for (size_t i = 0; i < data_size; i++) {
  133. half_data[i] = float16(src_data[i]);
  134. }
  135. }
  136. bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
  137. using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const size_t)>;
  138. const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
  139. {FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
  140. {FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>},
  141. {FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
  142. {FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
  143. {FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
  144. {FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
  145. {FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
  146. {FROM_INT32_TO_INT64, TransDataSrc2Dst<int32_t, int64_t>},
  147. {FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
  148. {FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
  149. {FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
  150. {FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
  151. {FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
  152. {FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
  153. {FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
  154. {FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
  155. {FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
  156. {FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
  157. {FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
  158. {FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
  159. {FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
  160. {FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
  161. {FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}};
  162. if (mode == FROM_FLOAT_TO_FLOAT16) {
  163. device::FloatToHalf(dst, args.data, data_size);
  164. return true;
  165. } else if (mode == FROM_FLOAT16_TO_FLOAT) {
  166. device::HalfToFloat(dst, args.data, data_size);
  167. return true;
  168. }
  169. auto iter = cast_kernel_map.find(mode);
  170. if (iter != cast_kernel_map.end()) {
  171. iter->second(args, dst, data_size);
  172. return true;
  173. } else {
  174. MS_LOG(ERROR) << "Unsupported datatype trans";
  175. return false;
  176. }
  177. }
  178. size_t CubeSizeByType(const TypeId data_type) {
  179. const size_t default_error = 0;
  180. auto dt_size = abstract::TypeIdSize(data_type);
  181. if (dt_size < 1) {
  182. MS_LOG(ERROR) << "Illegal dtype.";
  183. return default_error;
  184. } else if (dt_size == 1) {
  185. return kCubeSize * 2;
  186. }
  187. return kCubeSize;
  188. }
  189. namespace {
  190. bool CheckDims(const std::vector<size_t> &shape) {
  191. if (shape.size() != kNchwDims) {
  192. MS_LOG(ERROR) << "Host shape dims should be 4";
  193. return false;
  194. }
  195. return true;
  196. }
  197. std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
  198. if (!CheckDims(shape)) {
  199. MS_LOG(EXCEPTION) << "Check dims failed.";
  200. }
  201. return shape;
  202. }
  203. std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
  204. if (!CheckDims(shape)) {
  205. MS_LOG(EXCEPTION) << "Ccheck dims failed.";
  206. }
  207. std::vector<size_t> device_shape;
  208. device_shape.push_back(shape[kN]);
  209. device_shape.push_back(shape[kH]);
  210. device_shape.push_back(shape[kW]);
  211. device_shape.push_back(shape[kC]);
  212. return device_shape;
  213. }
  214. std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
  215. if (!CheckDims(shape)) {
  216. MS_LOG(EXCEPTION) << "Check dims failed.";
  217. }
  218. std::vector<size_t> device_shape;
  219. device_shape.push_back(shape[kH]);
  220. device_shape.push_back(shape[kW]);
  221. device_shape.push_back(shape[kC]);
  222. device_shape.push_back(shape[kN]);
  223. return device_shape;
  224. }
  225. std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
  226. if (!CheckDims(shape)) {
  227. MS_LOG(EXCEPTION) << "Check dims failed.";
  228. }
  229. std::vector<size_t> device_shape;
  230. const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  231. const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  232. device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
  233. device_shape.push_back(cout16 / kCubeSize);
  234. device_shape.push_back(kCubeSize);
  235. device_shape.push_back(kCubeSize);
  236. return device_shape;
  237. }
  238. std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  239. if (!CheckDims(shape)) {
  240. MS_LOG(EXCEPTION) << "Check dims failed.";
  241. }
  242. std::vector<size_t> device_shape;
  243. const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
  244. const size_t C0 = kCubeSize;
  245. device_shape.push_back(shape[kN]);
  246. device_shape.push_back(C1);
  247. device_shape.push_back(shape[kH]);
  248. device_shape.push_back(shape[kW]);
  249. device_shape.push_back(C0);
  250. return device_shape;
  251. }
  252. std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  253. // NCDHW
  254. if (shape.size() != 5) {
  255. MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  256. }
  257. std::vector<size_t> device_shape;
  258. const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
  259. const size_t C0 = kCubeSize;
  260. device_shape.push_back(shape[0]);
  261. device_shape.push_back(shape[2]);
  262. device_shape.push_back(C1);
  263. device_shape.push_back(shape[3]);
  264. device_shape.push_back(shape[4]);
  265. device_shape.push_back(C0);
  266. return device_shape;
  267. }
  268. std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
  269. // NCDHW -> Frac_Z_3D
  270. if (shape.size() != 5) {
  271. MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
  272. }
  273. std::vector<size_t> device_shape;
  274. const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
  275. const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
  276. device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
  277. device_shape.push_back(N1);
  278. device_shape.push_back(kCubeSize);
  279. device_shape.push_back(kCubeSize);
  280. return device_shape;
  281. }
  282. std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
  283. if (!CheckDims(shape)) {
  284. MS_LOG(EXCEPTION) << "Check dims failed.";
  285. }
  286. std::vector<size_t> device_shape;
  287. device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
  288. device_shape.push_back(shape[kH]);
  289. device_shape.push_back(shape[kW]);
  290. device_shape.push_back(shape[kN]);
  291. device_shape.push_back(kCubeSize);
  292. device_shape.push_back(kCubeSize);
  293. return device_shape;
  294. }
  295. std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
  296. if (!CheckDims(shape)) {
  297. MS_LOG(EXCEPTION) << "Check dims failed.";
  298. }
  299. std::vector<size_t> device_shape;
  300. const size_t c0 = 4;
  301. auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
  302. auto no = DivCeil(shape.at(kN), kCubeSize);
  303. device_shape.push_back(first_dim);
  304. device_shape.push_back(no);
  305. device_shape.push_back(kCubeSize);
  306. device_shape.push_back(kCubeSize);
  307. return device_shape;
  308. }
  309. std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
  310. if (!CheckDims(shape)) {
  311. MS_LOG(EXCEPTION) << "Check dims failed.";
  312. }
  313. std::vector<size_t> device_shape;
  314. const size_t C1 = 1;
  315. const size_t C0 = 4;
  316. device_shape.push_back(shape[kN]);
  317. device_shape.push_back(C1);
  318. device_shape.push_back(shape[kH]);
  319. device_shape.push_back(shape[kW]);
  320. device_shape.push_back(C0);
  321. return device_shape;
  322. }
  323. std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
  324. if (shape.size() < kNdhwc) {
  325. MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  326. }
  327. return shape;
  328. }
  329. std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
  330. std::vector<size_t> shape_4d(kNchwDims, 1);
  331. switch (shape.size()) {
  332. case 0:
  333. return shape_4d;
  334. case 1:
  335. shape_4d[kC] = shape[kN];
  336. break;
  337. case 2:
  338. shape_4d[kC] = shape[kN];
  339. shape_4d[kH] = shape[kC];
  340. break;
  341. case 3:
  342. shape_4d[kC] = shape[kN];
  343. shape_4d[kH] = shape[kC];
  344. shape_4d[kW] = shape[kH];
  345. break;
  346. case 4:
  347. std::copy(shape.begin(), shape.end(), shape_4d.begin());
  348. break;
  349. default:
  350. MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
  351. }
  352. return shape_4d;
  353. }
  354. } // namespace
  355. bool IsNeedPadding(const std::string &format, const size_t shape_size) {
  356. if (shape_size == 0) {
  357. return false;
  358. }
  359. if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
  360. return false;
  361. } else if (shape_size < kNchwDims) {
  362. return true;
  363. }
  364. return false;
  365. }
  366. ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  367. MS_EXCEPTION_IF_NULL(node);
  368. ShapeVector shape;
  369. std::vector<size_t> host_shape;
  370. if (node->isa<ValueNode>()) {
  371. auto value_node = node->cast<ValueNodePtr>();
  372. MS_EXCEPTION_IF_NULL(value_node);
  373. auto node_value = value_node->value();
  374. MS_EXCEPTION_IF_NULL(node_value);
  375. auto tensor = node_value->cast<tensor::TensorPtr>();
  376. if (tensor == nullptr) {
  377. MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
  378. }
  379. auto shape_temp = tensor->shape();
  380. (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), LongToSize);
  381. if (host_shape.empty()) {
  382. host_shape.push_back(1);
  383. }
  384. } else {
  385. host_shape = AnfAlgo::GetOutputInferShape(node, index);
  386. }
  387. if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) {
  388. host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index));
  389. }
  390. std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
  391. return shape;
  392. }
  393. std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
  394. if (padding_axis.empty() || shape.size() != padding_axis.size()) {
  395. return PaddingShapeTo4dByDefault(shape);
  396. }
  397. std::vector<size_t> shape_4d(kNchwDims, 1);
  398. for (size_t index = 0; index < padding_axis.size(); index++) {
  399. shape_4d[padding_axis[index]] = shape[index];
  400. }
  401. return shape_4d;
  402. }
  403. std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
  404. using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
  405. const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
  406. {kOpFormat_NHWC, NhwcDeviceShape},
  407. {kOpFormat_HWCN, HwchDeviceShape},
  408. {kOpFormat_FRAC_Z, FracZDeviceShape},
  409. {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
  410. {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
  411. {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
  412. {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
  413. {kOpFormat_NCDHW, NcdhwDeviceShape},
  414. {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
  415. {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}};
  416. if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
  417. return shape;
  418. }
  419. auto temp_shape = shape;
  420. std::vector<size_t> device_shape;
  421. if (format == kOpFormat_FRAC_NZ) {
  422. if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
  423. // For [1] and [1024] shape we can trait it as NZ shape
  424. return shape;
  425. }
  426. if (shape.size() < 2) {
  427. MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
  428. } else {
  429. (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
  430. }
  431. auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
  432. auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1;
  433. device_shape.push_back(w1);
  434. device_shape.push_back(h1);
  435. device_shape.push_back(kCubeSize);
  436. device_shape.push_back(kCubeSize);
  437. return device_shape;
  438. } else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
  439. const size_t c0 = 4;
  440. const size_t h = shape.at(kN) / c0;
  441. const size_t i = shape.at(kC) - h;
  442. const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
  443. const size_t second = c0 * DivCeil(h, kCubeSize);
  444. device_shape.push_back(first);
  445. device_shape.push_back(second);
  446. device_shape.push_back(kCubeSize);
  447. device_shape.push_back(kCubeSize);
  448. return device_shape;
  449. }
  450. if (shape.size() != kNchwDims && shape.size() != 5) {
  451. MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
  452. temp_shape = PaddingShapeTo4dByDefault(shape);
  453. }
  454. auto iter = device_shape_map.find(format);
  455. if (iter == device_shape_map.end()) {
  456. MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  457. }
  458. return iter->second(temp_shape);
  459. }
  460. bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
  461. if (args.host_shape.size() != kNchwDims) {
  462. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  463. return false;
  464. }
  465. MS_EXCEPTION_IF_NULL(size);
  466. MS_EXCEPTION_IF_NULL(total_size);
  467. *size = abstract::TypeIdSize(args.src_data_type);
  468. if (*size < 1) {
  469. MS_LOG(ERROR) << "Illegal dtype.";
  470. return false;
  471. }
  472. *total_size = abstract::ShapeSize(args.device_shape) * (*size);
  473. if (*total_size != args.device_size) {
  474. MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
  475. return false;
  476. }
  477. return true;
  478. }
  479. bool TransDataType(const TypeIdArgs &args, void *result) {
  480. MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
  481. << TypeIdLabel(args.device_data_type);
  482. MS_EXCEPTION_IF_NULL(result);
  483. std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
  484. auto iter = mode_map.find(type_info);
  485. if (iter == mode_map.end()) {
  486. MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
  487. << ", dst_type:" << TypeIdLabel(args.device_data_type);
  488. return false;
  489. }
  490. auto trans_mode = iter->second;
  491. if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
  492. MS_LOG(ERROR) << "Failed to trans datatype..";
  493. return false;
  494. }
  495. return true;
  496. }
  497. bool TransFormat(const FormatArgs &args, void *result) {
  498. MS_LOG(DEBUG) << "Start trans format.";
  499. if (abstract::TypeIdSize(args.src_data_type) < 1) {
  500. MS_LOG(ERROR) << "Invalid datatype..";
  501. return false;
  502. }
  503. if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
  504. return NchwTo4D(args, result);
  505. }
  506. auto iter = kTransFormatMapOfHostToDevice.find(args.device_format);
  507. if (iter == kTransFormatMapOfHostToDevice.end()) {
  508. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  509. }
  510. return iter->second(args, result);
  511. }
  512. bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
  513. const std::map<std::string, FormatTransfer> format_trans_map{
  514. {kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
  515. {kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
  516. {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw},
  517. {kOpFormat_FRACTAL_Z_3D, FracZ3DToNcdhw}};
  518. MS_LOG(DEBUG) << "Start trans format.";
  519. if (abstract::TypeIdSize(args.src_data_type) < 1) {
  520. MS_LOG(ERROR) << "Invalid datatype..";
  521. return false;
  522. }
  523. if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
  524. return ToNchw(args, result);
  525. }
  526. auto iter = format_trans_map.find(args.device_format);
  527. if (iter == format_trans_map.end()) {
  528. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  529. }
  530. return iter->second(args, result);
  531. }
  532. bool NchwTo4D(const FormatArgs &args, void *result) {
  533. // trans nchw to 4d
  534. MS_LOG(DEBUG) << "Trans format from nchw to 4d.";
  535. MS_EXCEPTION_IF_NULL(result);
  536. size_t size = 0;
  537. size_t total_size = 0;
  538. if (!CheckArgs(args, &size, &total_size)) {
  539. MS_LOG(ERROR) << "Check args failed.";
  540. return false;
  541. }
  542. auto n = args.host_shape[kN];
  543. auto c = args.host_shape[kC];
  544. auto h = args.host_shape[kH];
  545. auto w = args.host_shape[kW];
  546. for (size_t ni = 0; ni < n; ni++) {
  547. for (size_t ci = 0; ci < c; ci++) {
  548. for (size_t hi = 0; hi < h; hi++) {
  549. for (size_t wi = 0; wi < w; wi++) {
  550. auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
  551. auto dst_idx = 0;
  552. if (args.device_format == kOpFormat_NHWC) {
  553. dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
  554. } else if (args.device_format == kOpFormat_HWCN) {
  555. dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
  556. }
  557. SetData(size, false, src_idx, dst_idx, args, result);
  558. }
  559. }
  560. }
  561. }
  562. return true;
  563. }
  564. bool ToNchw(const FormatArgs &args, void *result) {
  565. MS_LOG(DEBUG) << "Trans format to nchw from 4d.";
  566. MS_EXCEPTION_IF_NULL(result);
  567. size_t size = 0;
  568. size_t total_size = 0;
  569. if (!CheckArgs(args, &size, &total_size)) {
  570. MS_LOG(ERROR) << "Check args failed.";
  571. return false;
  572. }
  573. auto n = args.host_shape[kN];
  574. auto c = args.host_shape[kC];
  575. auto h = args.host_shape[kH];
  576. auto w = args.host_shape[kW];
  577. for (size_t ni = 0; ni < n; ni++) {
  578. for (size_t ci = 0; ci < c; ci++) {
  579. for (size_t hi = 0; hi < h; hi++) {
  580. for (size_t wi = 0; wi < w; wi++) {
  581. auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
  582. auto src_idx = 0;
  583. if (args.device_format == kOpFormat_NHWC) {
  584. src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
  585. } else if (args.device_format == kOpFormat_HWCN) {
  586. src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
  587. }
  588. SetData(size, false, src_idx, dst_idx, args, result);
  589. }
  590. }
  591. }
  592. }
  593. return true;
  594. }
  595. bool NchwToFracZ(const FormatArgs &args, void *result) {
  596. MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  597. MS_EXCEPTION_IF_NULL(result);
  598. if (args.host_shape.size() != kNchwDims) {
  599. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  600. return false;
  601. }
  602. auto size = abstract::TypeIdSize(args.src_data_type);
  603. if (size < 1) {
  604. MS_LOG(ERROR) << "Illegal dtype.";
  605. return false;
  606. }
  607. auto n = args.host_shape[kN];
  608. auto c = args.host_shape[kC];
  609. auto h = args.host_shape[kH];
  610. auto w = args.host_shape[kW];
  611. auto c0 = CubeSizeByType(args.src_data_type);
  612. if (c0 < 1) {
  613. MS_LOG(ERROR) << "Illegal dtype.";
  614. return false;
  615. }
  616. auto c1 = DivCeil(c, c0);
  617. auto hw = h * w;
  618. auto chw = c * hw;
  619. auto hwc0 = hw * c0;
  620. auto nchw = n * chw;
  621. auto hf_cnt = DivCeil(n, kCubeSize);
  622. auto vf_cnt = c1 * hw;
  623. auto fractal_ele_cnt = c0 * kCubeSize;
  624. auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
  625. auto dst_size = total_ele_cnt * size;
  626. if (dst_size != args.device_size) {
  627. MS_LOG(ERROR) << "Illegal total data size."
  628. << "dst size is :" << dst_size << "device size is :" << args.device_size;
  629. return false;
  630. }
  631. for (size_t vfi = 0; vfi < vf_cnt; vfi++) {
  632. auto vf_base_i = vfi * hf_cnt; // vertical fractal matrix base index
  633. for (size_t hfi = 0; hfi < hf_cnt; hfi++) {
  634. auto gfi = vf_base_i + hfi; // global fractal matrix index
  635. auto src_n_offset = hfi * chw * kCubeSize;
  636. auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
  637. for (size_t row = 0; row < c0; row++) {
  638. auto src_ci = vfi / hw * c0 + row;
  639. auto src_row_offset = src_f_offset + row * hw;
  640. for (size_t col = 0; col < kCubeSize; col++) {
  641. auto src_ni = hfi * kCubeSize + col;
  642. auto src_idx = src_row_offset + chw * col;
  643. auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
  644. auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
  645. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  646. }
  647. }
  648. }
  649. }
  650. return true;
  651. }
  652. bool FracZToNchw(const FormatArgs &args, void *result) {
  653. MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
  654. MS_EXCEPTION_IF_NULL(result);
  655. if (args.host_shape.size() != kNchwDims) {
  656. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  657. return false;
  658. }
  659. auto size = abstract::TypeIdSize(args.src_data_type);
  660. if (size < 1) {
  661. MS_LOG(ERROR) << "Illegal dtype.";
  662. return false;
  663. }
  664. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  665. if (total_size != args.device_size) {
  666. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  667. return false;
  668. }
  669. auto n0 = args.device_shape.at(1);
  670. auto ni = args.device_shape.at(2);
  671. auto c0 = args.device_shape.at(3);
  672. auto n = args.host_shape[kN];
  673. auto c = args.host_shape[kC];
  674. auto h = args.host_shape[kH];
  675. auto w = args.host_shape[kW];
  676. auto nc = ni * n0;
  677. auto ncc0 = nc * c0;
  678. auto wncc0 = w * ncc0;
  679. auto hwncc0 = h * wncc0;
  680. auto hw = h * w;
  681. auto chw = c * hw;
  682. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  683. size_t n_head_addr = n_idx * chw;
  684. for (size_t c_idx = 0; c_idx < c; c_idx++) {
  685. size_t c_head_addr = n_head_addr + c_idx * hw;
  686. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  687. size_t h_head_addr = c_head_addr + h_idx * w;
  688. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  689. size_t dst_idx = h_head_addr + w_idx;
  690. size_t c1_idx = c_idx / c0;
  691. size_t c0_idx = c_idx % c0;
  692. size_t nc_idx = n_idx;
  693. size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
  694. SetData(size, false, src_idx, dst_idx, args, result);
  695. }
  696. }
  697. }
  698. }
  699. return true;
  700. }
  701. bool NchwToFracZc04(const FormatArgs &args, void *result) {
  702. // trans nchw to FracZc04
  703. MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
  704. MS_EXCEPTION_IF_NULL(result);
  705. size_t size = 0;
  706. size_t total_size = 0;
  707. if (!CheckArgs(args, &size, &total_size)) {
  708. MS_LOG(ERROR) << "Check args failed.";
  709. return false;
  710. }
  711. auto cube = kCubeSize;
  712. auto n = args.host_shape[kN];
  713. auto c = args.host_shape[kC];
  714. auto h = args.host_shape[kH];
  715. auto w = args.host_shape[kW];
  716. const size_t c0 = 4;
  717. auto c1 = DivCeil(c, c0);
  718. auto hwc0 = h * w * c0;
  719. auto hwc = h * w * c;
  720. auto nhwc = n * h * w * c;
  721. auto n_cnt = DivCeil(n, cube);
  722. auto v_cnt = DivCeil(h * w * c0 * c1, cube);
  723. size_t dst_idx = 0;
  724. for (size_t vi = 0; vi < v_cnt; vi++) {
  725. for (size_t ni = 0; ni < n_cnt; ni++) {
  726. for (size_t col = 0; col < cube; col++) {
  727. for (size_t row = 0; row < cube; row++) {
  728. size_t cur_cube_n = cube * ni + col;
  729. size_t cur_cube_c1hwc0 = cube * vi + row;
  730. auto desc_g = cur_cube_n / n;
  731. auto desc_n = cur_cube_n % n;
  732. auto desc_c1 = cur_cube_c1hwc0 / hwc0;
  733. auto desc_c0 = cur_cube_c1hwc0 % c0;
  734. auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
  735. auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
  736. auto c_idx = desc_c1 * c0 + desc_c0;
  737. auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
  738. auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
  739. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  740. dst_idx++;
  741. }
  742. }
  743. }
  744. }
  745. return true;
  746. }
  747. bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
  748. MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
  749. return NchwToNc1hwc0(args, result);
  750. }
  751. bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
  752. MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
  753. return Nc1hwc0ToNchw(args, result);
  754. }
  755. bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
  756. MS_EXCEPTION_IF_NULL(hw_shape);
  757. if (host_shape.empty()) {
  758. MS_LOG(ERROR) << "Size of vector is 0.";
  759. return false;
  760. }
  761. switch (host_shape.size()) {
  762. case 1:
  763. hw_shape->push_back(1);
  764. hw_shape->push_back(1);
  765. hw_shape->push_back(host_shape[0]);
  766. return true;
  767. default:
  768. auto size = host_shape.size();
  769. if (size < 2) {
  770. MS_LOG(ERROR) << "Illegal size.";
  771. return false;
  772. }
  773. size_t times = 1;
  774. for (size_t i = 0; i != size - 2; i++) {
  775. times *= host_shape[i];
  776. }
  777. hw_shape->push_back(times);
  778. hw_shape->push_back(host_shape[size - 2]);
  779. hw_shape->push_back(host_shape[size - 1]);
  780. return true;
  781. }
  782. }
  783. bool NchwToFracNz(const FormatArgs &args, void *result) {
  784. MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  785. MS_EXCEPTION_IF_NULL(result);
  786. std::vector<size_t> hw_shape;
  787. if (!TransShapeToNz(args.host_shape, &hw_shape)) {
  788. MS_LOG(ERROR) << "Trans shape failed..";
  789. return false;
  790. }
  791. if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
  792. MS_LOG(ERROR) << "Invalid shape size.";
  793. return false;
  794. }
  795. auto size = abstract::TypeIdSize(args.src_data_type);
  796. if (size < 1) {
  797. MS_LOG(ERROR) << "Illegal dtype";
  798. return false;
  799. }
  800. auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  801. if (dst_size != args.device_size) {
  802. MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
  803. return false;
  804. }
  805. auto times = hw_shape.at(0);
  806. auto h = hw_shape.at(1);
  807. auto w = hw_shape.at(2);
  808. auto hw = h * w;
  809. auto shape_size = args.device_shape.size();
  810. auto w1 = args.device_shape[shape_size - 4];
  811. auto h1 = args.device_shape[shape_size - 3];
  812. auto h0 = args.device_shape[shape_size - 2];
  813. auto w0 = args.device_shape[shape_size - 1];
  814. auto h1h0w0 = h1 * h0 * w0;
  815. auto w1h1h0w0 = w1 * h1h0w0;
  816. auto num_w1 = w / w0;
  817. for (size_t times_idx = 0; times_idx < times; times_idx++) {
  818. auto times_head = times_idx * w1h1h0w0;
  819. auto src_times_head = times_idx * hw;
  820. for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
  821. auto h1h0_head = times_head + h1h0_idx * w0;
  822. auto src_h_head = src_times_head + h1h0_idx * w;
  823. for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
  824. for (size_t i = 0; i < w0; ++i) {
  825. size_t src_idx = src_h_head + w1_idx * w0 + i;
  826. size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
  827. SetData(size, false, src_idx, dst_idx, args, result);
  828. }
  829. }
  830. auto w1_head = num_w1 * w0;
  831. for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
  832. auto src_w_idx = w1_head + w0_idx;
  833. size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
  834. size_t src_idx = src_h_head + src_w_idx;
  835. SetData(size, false, src_idx, dst_idx, args, result);
  836. }
  837. }
  838. }
  839. return true;
  840. }
  841. bool FracNzToNchw(const FormatArgs &args, void *result) {
  842. MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
  843. MS_EXCEPTION_IF_NULL(result);
  844. std::vector<size_t> hw_shape;
  845. if (!TransShapeToNz(args.host_shape, &hw_shape)) {
  846. MS_LOG(ERROR) << "Trans shape failed..";
  847. return false;
  848. }
  849. if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
  850. MS_LOG(ERROR) << "Invalid shape size.";
  851. return false;
  852. }
  853. auto size = abstract::TypeIdSize(args.src_data_type);
  854. if (size < 1) {
  855. MS_LOG(ERROR) << "Illegal dtype";
  856. return false;
  857. }
  858. auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  859. if (dst_size != args.device_size) {
  860. MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
  861. return false;
  862. }
  863. auto times = hw_shape.at(0);
  864. auto h = hw_shape.at(1);
  865. auto w = hw_shape.at(2);
  866. auto hw = h * w;
  867. auto shape_size = args.device_shape.size();
  868. auto w1 = args.device_shape[shape_size - 4];
  869. auto h1 = args.device_shape[shape_size - 3];
  870. auto h0 = args.device_shape[shape_size - 2];
  871. auto w0 = args.device_shape[shape_size - 1];
  872. auto h1h0w0 = h1 * h0 * w0;
  873. auto w1h1h0w0 = w1 * h1h0w0;
  874. auto num_w1 = w / w0;
  875. for (size_t times_idx = 0; times_idx < times; times_idx++) {
  876. auto times_head = times_idx * w1h1h0w0;
  877. auto src_times_head = times_idx * hw;
  878. for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
  879. auto h1h0_head = times_head + h1h0_idx * w0;
  880. auto src_h_head = src_times_head + h1h0_idx * w;
  881. for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
  882. for (size_t i = 0; i < w0; ++i) {
  883. size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
  884. size_t dst_idx = src_h_head + w1_idx * w0 + i;
  885. SetData(size, false, src_idx, dst_idx, args, result);
  886. }
  887. }
  888. auto w1_head = num_w1 * w0;
  889. for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
  890. auto src_w_idx = w1_head + w0_idx;
  891. size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
  892. size_t dst_idx = src_h_head + src_w_idx;
  893. SetData(size, false, src_idx, dst_idx, args, result);
  894. }
  895. }
  896. }
  897. return true;
  898. }
  899. bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
  900. MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
  901. MS_EXCEPTION_IF_NULL(result);
  902. if (args.host_shape.size() != kNchwDims) {
  903. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  904. return false;
  905. }
  906. auto size = abstract::TypeIdSize(args.src_data_type);
  907. if (size < 1) {
  908. MS_LOG(ERROR) << "Illegal dtype.";
  909. return false;
  910. }
  911. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  912. if (total_size != args.device_size) {
  913. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  914. return false;
  915. }
  916. auto n = args.host_shape[kN];
  917. auto c = args.host_shape[kC];
  918. auto h = args.host_shape[kH];
  919. auto w = args.host_shape[kW];
  920. auto c0 = CubeSizeByType(args.src_data_type);
  921. if (c0 < 1) {
  922. MS_LOG(ERROR) << "Illegal dtype.";
  923. return false;
  924. }
  925. if (args.device_format == kOpFormat_NC1HWC0_C04) {
  926. c0 = 4;
  927. }
  928. auto c1 = DivCeil(c, c0);
  929. auto hw = h * w;
  930. auto chw = c * hw;
  931. auto c1hwc0 = c1 * hw * c0;
  932. auto wc0 = w * c0;
  933. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  934. size_t n_head_addr = n_idx * c1hwc0;
  935. for (size_t c1_idx = 0; c1_idx < c1; c1_idx++) {
  936. size_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
  937. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  938. size_t h_head_addr = c1_head_addr + h_idx * wc0;
  939. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  940. size_t w_head_addr = h_head_addr + w_idx * c0;
  941. for (size_t c0_idx = 0; c0_idx < c0; c0_idx++) {
  942. size_t dst_idx = c0_idx + w_head_addr;
  943. size_t c_idx = c0_idx + c1_idx * c0;
  944. size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
  945. auto pad_zero = c_idx >= c;
  946. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  947. }
  948. }
  949. }
  950. }
  951. }
  952. return true;
  953. }
  954. bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
  955. MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
  956. MS_EXCEPTION_IF_NULL(result);
  957. if (args.host_shape.size() != kNchwDims) {
  958. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  959. return false;
  960. }
  961. auto size = abstract::TypeIdSize(args.src_data_type);
  962. if (size < 1) {
  963. MS_LOG(ERROR) << "Illegal dtype.";
  964. return false;
  965. }
  966. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  967. if (total_size != args.device_size) {
  968. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  969. return false;
  970. }
  971. auto n = args.host_shape[kN];
  972. auto c = args.host_shape[kC];
  973. auto h = args.host_shape[kH];
  974. auto w = args.host_shape[kW];
  975. auto c1 = args.device_shape[1];
  976. auto c0 = args.device_shape[4];
  977. auto hw = h * w;
  978. auto chw = c * hw;
  979. auto wc0 = w * c0;
  980. auto hwc0 = h * wc0;
  981. auto c1hwc0 = c1 * hwc0;
  982. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  983. size_t n_head_addr = n_idx * chw;
  984. for (size_t c_idx = 0; c_idx < c; c_idx++) {
  985. size_t c_head_addr = n_head_addr + c_idx * hw;
  986. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  987. size_t h_head_addr = c_head_addr + h_idx * w;
  988. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  989. size_t dst_idx = h_head_addr + w_idx;
  990. size_t c1_idx = c_idx / c0;
  991. size_t c0_idx = c_idx % c0;
  992. size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
  993. SetData(size, false, src_idx, dst_idx, args, result);
  994. }
  995. }
  996. }
  997. }
  998. return true;
  999. }
  1000. bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
  1001. // trans nchw to c1hwncoc0
  1002. MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
  1003. MS_EXCEPTION_IF_NULL(result);
  1004. size_t size = 0;
  1005. size_t total_size = 0;
  1006. if (!CheckArgs(args, &size, &total_size)) {
  1007. MS_LOG(ERROR) << "Check args failed.";
  1008. return false;
  1009. }
  1010. auto n = args.host_shape[kN];
  1011. auto c = args.host_shape[kC];
  1012. auto h = args.host_shape[kH];
  1013. auto w = args.host_shape[kW];
  1014. const int co_idx = 4;
  1015. const int c0_idx = 5;
  1016. auto c1 = args.device_shape[0];
  1017. auto co = args.device_shape[co_idx];
  1018. auto c0 = args.device_shape[c0_idx];
  1019. for (size_t c1_i = 0; c1_i < c1; c1_i++) {
  1020. for (size_t h_i = 0; h_i < h; h_i++) {
  1021. for (size_t w_i = 0; w_i < w; w_i++) {
  1022. for (size_t n_i = 0; n_i < n; n_i++) {
  1023. for (size_t co_i = 0; co_i < co; co_i++) {
  1024. for (size_t c0_i = 0; c0_i < c0; c0_i++) {
  1025. size_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
  1026. co_i * c0 + c0_i;
  1027. size_t c_i = c0_i + c1_i * c0;
  1028. size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
  1029. auto pad_zero = !(c_i < c && c0_i == co_i);
  1030. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  1031. }
  1032. }
  1033. }
  1034. }
  1035. }
  1036. }
  1037. return true;
  1038. }
  1039. bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
  1040. // trans c1hwncoc0 to nchw
  1041. MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
  1042. MS_EXCEPTION_IF_NULL(result);
  1043. size_t size = 0;
  1044. size_t total_size = 0;
  1045. if (!CheckArgs(args, &size, &total_size)) {
  1046. MS_LOG(ERROR) << "Check args failed.";
  1047. return false;
  1048. }
  1049. auto n = args.host_shape[kN];
  1050. auto c = args.host_shape[kC];
  1051. auto h = args.host_shape[kH];
  1052. auto w = args.host_shape[kW];
  1053. const int co_idx = 4;
  1054. const int c0_idx = 5;
  1055. auto co = args.device_shape[co_idx];
  1056. auto c0 = args.device_shape[c0_idx];
  1057. for (size_t n_i = 0; n_i < n; n_i++) {
  1058. for (size_t c_i = 0; c_i < c; c_i++) {
  1059. for (size_t h_i = 0; h_i < h; h_i++) {
  1060. for (size_t w_i = 0; w_i < w; w_i++) {
  1061. size_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
  1062. size_t c1_i = c_i / kCubeSize;
  1063. size_t c0_i = c_i % kCubeSize;
  1064. size_t co_i = c0_i;
  1065. size_t src_idx =
  1066. c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
  1067. SetData(size, false, src_idx, dst_idx, args, result);
  1068. }
  1069. }
  1070. }
  1071. }
  1072. return true;
  1073. }
  1074. bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
  1075. MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
  1076. MS_EXCEPTION_IF_NULL(result);
  1077. if (args.host_shape.size() != 5) {
  1078. MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
  1079. return false;
  1080. }
  1081. auto size = abstract::TypeIdSize(args.src_data_type);
  1082. if (size < 1) {
  1083. MS_LOG(ERROR) << "Illegal dtype.";
  1084. return false;
  1085. }
  1086. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  1087. if (total_size != args.device_size) {
  1088. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  1089. return false;
  1090. }
  1091. auto n = args.host_shape[0];
  1092. auto c = args.host_shape[1];
  1093. auto d = args.host_shape[2];
  1094. auto h = args.host_shape[3];
  1095. auto w = args.host_shape[4];
  1096. auto c1 = args.device_shape[2];
  1097. auto c0 = args.device_shape[5];
  1098. const size_t cdhw = c * d * h * w;
  1099. const size_t dhw = d * h * w;
  1100. const size_t hw = h * w;
  1101. const size_t dc1hwc0 = d * c1 * h * w * c0;
  1102. const size_t c1hwc0 = c1 * h * w * c0;
  1103. const size_t hwc0 = h * w * c0;
  1104. const size_t wc0 = w * c0;
  1105. for (size_t n_i = 0; n_i < n; n_i++) {
  1106. size_t n_head = n_i * cdhw;
  1107. for (size_t c_i = 0; c_i < c; c_i++) {
  1108. size_t c_head = n_head + c_i * dhw;
  1109. for (size_t d_i = 0; d_i < d; d_i++) {
  1110. size_t d_head = c_head + d_i * hw;
  1111. for (size_t h_i = 0; h_i < h; h_i++) {
  1112. size_t h_head = d_head + h_i * w;
  1113. for (size_t w_i = 0; w_i < w; w_i++) {
  1114. size_t dst_i = h_head + w_i;
  1115. size_t c1_i = c_i / c0;
  1116. size_t c0_i = c_i % c0;
  1117. auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
  1118. SetData(size, false, src_idx, dst_i, args, result);
  1119. }
  1120. }
  1121. }
  1122. }
  1123. }
  1124. return true;
  1125. }
  1126. bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
  1127. MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
  1128. MS_EXCEPTION_IF_NULL(result);
  1129. if (args.host_shape.size() != 5) {
  1130. MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
  1131. return false;
  1132. }
  1133. auto size = abstract::TypeIdSize(args.src_data_type);
  1134. if (size < 1) {
  1135. MS_LOG(ERROR) << "Illegal dtype.";
  1136. return false;
  1137. }
  1138. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  1139. if (total_size != args.device_size) {
  1140. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  1141. return false;
  1142. }
  1143. auto n = args.host_shape[0];
  1144. auto c = args.host_shape[1];
  1145. auto d = args.host_shape[2];
  1146. auto h = args.host_shape[3];
  1147. auto w = args.host_shape[4];
  1148. auto c0 = kCubeSize;
  1149. auto c1 = DivCeil(c, c0);
  1150. const size_t cdhw = c * d * h * w;
  1151. const size_t dhw = d * h * w;
  1152. const size_t hw = h * w;
  1153. const size_t dc1hwc0 = d * c1 * h * w * c0;
  1154. const size_t c1hwc0 = c1 * h * w * c0;
  1155. const size_t hwc0 = h * w * c0;
  1156. const size_t wc0 = w * c0;
  1157. for (size_t n_i = 0; n_i < n; n_i++) {
  1158. size_t n_head = n_i * dc1hwc0;
  1159. for (size_t d_i = 0; d_i < d; d_i++) {
  1160. size_t d_head = n_head + d_i * c1hwc0;
  1161. for (size_t c1_i = 0; c1_i < c1; c1_i++) {
  1162. size_t c1_head = d_head + c1_i * hwc0;
  1163. for (size_t h_i = 0; h_i < h; h_i++) {
  1164. size_t h_head = c1_head + h_i * wc0;
  1165. for (size_t w_i = 0; w_i < w; w_i++) {
  1166. size_t w_head = h_head + w_i * c0;
  1167. for (size_t c0_i = 0; c0_i < c0; c0_i++) {
  1168. size_t dst_i = c0_i + w_head;
  1169. size_t c_i = c0_i + c1_i * c0;
  1170. size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
  1171. auto pad_zero = c_i >= c;
  1172. SetData(size, pad_zero, src_i, dst_i, args, result);
  1173. }
  1174. }
  1175. }
  1176. }
  1177. }
  1178. }
  1179. return true;
  1180. }
  1181. bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
  1182. MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
  1183. MS_EXCEPTION_IF_NULL(result);
  1184. if (args.host_shape.size() != 5) {
  1185. MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
  1186. return false;
  1187. }
  1188. auto size = abstract::TypeIdSize(args.src_data_type);
  1189. if (size < 1) {
  1190. MS_LOG(ERROR) << "Illegal dtype.";
  1191. return false;
  1192. }
  1193. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  1194. if (total_size != args.device_size) {
  1195. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  1196. return false;
  1197. }
  1198. auto n = args.host_shape[0];
  1199. auto c = args.host_shape[1];
  1200. auto d = args.host_shape[2];
  1201. auto h = args.host_shape[3];
  1202. auto w = args.host_shape[4];
  1203. auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
  1204. auto c0 = CubeSizeByType(args.src_data_type);
  1205. auto c1 = DivCeil(c, c0);
  1206. auto hw = h * w;
  1207. auto dhw = d * hw;
  1208. auto cdhw = c * dhw;
  1209. auto n1n0c0 = n1n0 * c0;
  1210. auto wn1n0c0 = w * n1n0c0;
  1211. auto hwn1n0c0 = h * wn1n0c0;
  1212. auto dhwn1n0c0 = d * hwn1n0c0;
  1213. for (size_t c1_i = 0; c1_i < c1; c1_i++) {
  1214. for (size_t d_i = 0; d_i < d; d_i++) {
  1215. for (size_t h_i = 0; h_i < h; h_i++) {
  1216. for (size_t w_i = 0; w_i < w; w_i++) {
  1217. for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
  1218. for (size_t c0_i = 0; c0_i < c0; c0_i++) {
  1219. size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
  1220. // ncdhw
  1221. size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
  1222. auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
  1223. SetData(size, pad_zero, src_i, dst_i, args, result);
  1224. }
  1225. }
  1226. }
  1227. }
  1228. }
  1229. }
  1230. return true;
  1231. }
  1232. bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
  1233. MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
  1234. MS_EXCEPTION_IF_NULL(result);
  1235. if (args.host_shape.size() != 5) {
  1236. MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
  1237. return false;
  1238. }
  1239. auto size = abstract::TypeIdSize(args.src_data_type);
  1240. if (size < 1) {
  1241. MS_LOG(ERROR) << "Illegal dtype.";
  1242. return false;
  1243. }
  1244. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  1245. if (total_size != args.device_size) {
  1246. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  1247. return false;
  1248. }
  1249. auto n = args.host_shape[0];
  1250. auto c = args.host_shape[1];
  1251. auto d = args.host_shape[2];
  1252. auto h = args.host_shape[3];
  1253. auto w = args.host_shape[4];
  1254. auto n0 = args.device_shape[1];
  1255. auto ni = args.device_shape[1];
  1256. auto c0 = args.device_shape[3];
  1257. auto hw = h * w;
  1258. auto dhw = d * hw;
  1259. auto cdhw = c * dhw;
  1260. auto nc = ni * n0;
  1261. auto ncc0 = nc * c0;
  1262. auto wncc0 = w * ncc0;
  1263. auto hwncc0 = h * wncc0;
  1264. auto dhwncc0 = d * hwncc0;
  1265. for (size_t n_i = 0; n_i < n; n_i++) {
  1266. size_t n_head = n_i * cdhw;
  1267. for (size_t c_i = 0; c_i < c; c_i++) {
  1268. size_t c_head = n_head + c_i * dhw;
  1269. for (size_t d_i = 0; d_i < d; d_i++) {
  1270. size_t d_head = c_head + d_i * hw;
  1271. for (size_t h_i = 0; h_i < h; h_i++) {
  1272. size_t h_head = d_head + h_i * w;
  1273. for (size_t w_i = 0; w_i < w; w_i++) {
  1274. size_t dst_i = h_head + w_i;
  1275. size_t c1_i = c_i / c0;
  1276. size_t c0_i = c_i % c0;
  1277. size_t nc_i = n_i;
  1278. size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i;
  1279. SetData(size, false, src_i, dst_i, args, result);
  1280. }
  1281. }
  1282. }
  1283. }
  1284. }
  1285. return true;
  1286. }
  1287. } // namespace trans
  1288. } // namespace mindspore