You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trans.cc 41 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/trans.h"
  17. #include <functional>
  18. #include <numeric>
  19. #include <utility>
  20. #include "common/utils.h"
  21. #include "session/anf_runtime_algorithm.h"
  22. #include "kernel/kernel.h"
  23. #include "device/convert_tensor_utils.h"
  24. #include "utils/convert_utils.h"
  25. #include "utils/log_adapter.h"
  26. #include "utils/utils.h"
  27. namespace mindspore {
  28. namespace trans {
  // Indices of the N, C, H and W dimensions inside a 4-D host shape. The two trailing enumerators simply continue the
  // numbering and are used in this file as dimension counts: kNchwDims == 4 and kNdhwc == 5.
  enum kAxis : int {
    kN = 0,
    kC,
    kH,
    kW,
    kNchwDims,
    kNdhwc
  };
  30. const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1},
  31. {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
  32. {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2},
  33. {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4},
  34. {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
  35. inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
  36. switch (size) {
  37. case 1:
  38. static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
  39. break;
  40. case 2:
  41. static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
  42. break;
  43. case 4:
  44. static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
  45. break;
  46. case 8:
  47. static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
  48. break;
  49. default:
  50. MS_LOG(EXCEPTION) << "Trans data not support size " << size;
  51. }
  52. }
  // Ceiling division: the smallest q such that q * n2 >= n1, for positive operands.
  // Returns 0 when the divisor is 0 (kept from the original contract) and when the dividend is 0. The dividend
  // guard matters for unsigned T: without it `n1 - 1` wraps around and the result is a huge bogus quotient
  // instead of 0.
  template <typename T>
  T DivCeil(T n1, T n2) {
    const T zero = static_cast<T>(0);
    if (n2 == zero || n1 == zero) {
      return zero;
    }
    return (n1 - 1) / n2 + 1;
  }
  // Identifies one supported (source type -> destination type) element conversion. The enumerators are looked up
  // through mode_map and dispatched by CastKernel; their order defines their implicit numeric values.
  enum DataTypeTransMode {
    // float32 source
    FROM_FLOAT_TO_FLOAT16,
    FROM_FLOAT_TO_INT32,
    // float16 source
    FROM_FLOAT16_TO_FLOAT,
    FROM_FLOAT16_TO_INT32,
    FROM_FLOAT16_TO_UINT8,
    // int32 source
    FROM_INT32_TO_FLOAT,
    FROM_INT32_TO_FLOAT16,
    FROM_INT32_TO_UINT8,
    FROM_INT32_TO_INT8,
    FROM_INT32_TO_BOOL,
    // uint8 source
    FROM_UINT8_TO_FLOAT,
    FROM_UINT8_TO_INT32,
    FROM_UINT8_TO_FLOAT16,
    // int8 source
    FROM_INT8_TO_FLOAT,
    FROM_INT8_TO_FLOAT16,
    FROM_INT8_TO_INT32,
    // wider / other integer sources
    FROM_INT64_TO_INT32,
    FROM_UINT16_TO_INT32,
    // bool source
    FROM_BOOL_TO_FLOAT,
    FROM_BOOL_TO_INT32,
    FROM_BOOL_TO_UINT8,
    FROM_BOOL_TO_FLOAT16,
    // float64 <-> float32
    FROM_FLOAT64_TO_FLOAT32,
    FROM_FLOAT32_TO_FLOAT64
  };
  86. const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
  87. {std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), FROM_FLOAT64_TO_FLOAT32},
  88. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), FROM_FLOAT32_TO_FLOAT64},
  89. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), FROM_FLOAT_TO_FLOAT16},
  90. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), FROM_FLOAT_TO_INT32},
  91. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), FROM_FLOAT16_TO_FLOAT},
  92. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), FROM_FLOAT16_TO_INT32},
  93. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), FROM_FLOAT16_TO_UINT8},
  94. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), FROM_INT32_TO_FLOAT},
  95. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), FROM_INT32_TO_FLOAT16},
  96. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), FROM_INT32_TO_UINT8},
  97. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), FROM_INT32_TO_INT8},
  98. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), FROM_INT32_TO_BOOL},
  99. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), FROM_UINT8_TO_FLOAT},
  100. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32},
  101. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), FROM_UINT8_TO_FLOAT16},
  102. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT},
  103. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), FROM_INT8_TO_FLOAT16},
  104. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32},
  105. {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32},
  106. {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32},
  107. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), FROM_BOOL_TO_INT32},
  108. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), FROM_BOOL_TO_FLOAT},
  109. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), FROM_BOOL_TO_UINT8},
  110. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};
  111. void CheckMemSize(const TypeIdArgs &args) {
  112. auto src_type_size = TypeIdSize(args.host_data_type);
  113. auto dst_type_size = TypeIdSize(args.device_data_type);
  114. if (src_type_size < 1 || dst_type_size < 1) {
  115. MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
  116. }
  117. if (args.data_size / src_type_size != args.host_shape_size) {
  118. MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
  119. }
  120. }
  121. template <typename SrcT, typename DstT>
  122. void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
  123. CheckMemSize(args);
  124. for (size_t idx = 0; idx != data_size; idx++) {
  125. SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
  126. static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
  127. }
  128. }
  129. template <typename SrcT>
  130. void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
  131. CheckMemSize(args);
  132. auto src_data = static_cast<const SrcT *>(args.data);
  133. auto half_data = static_cast<Eigen::half *>(dst);
  134. for (size_t i = 0; i < data_size; i++) {
  135. half_data[i] = Eigen::half(src_data[i]);
  136. }
  137. }
  138. bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
  139. using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const size_t)>;
  140. const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
  141. {FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
  142. {FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>},
  143. {FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
  144. {FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
  145. {FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
  146. {FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
  147. {FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
  148. {FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
  149. {FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
  150. {FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
  151. {FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
  152. {FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
  153. {FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
  154. {FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
  155. {FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
  156. {FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
  157. {FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
  158. {FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
  159. {FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
  160. {FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
  161. {FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
  162. {FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}};
  163. if (mode == FROM_FLOAT_TO_FLOAT16) {
  164. device::FloatToHalf(dst, args.data, data_size);
  165. return true;
  166. } else if (mode == FROM_FLOAT16_TO_FLOAT) {
  167. device::HalfToFloat(dst, args.data, data_size);
  168. return true;
  169. }
  170. auto iter = cast_kernel_map.find(mode);
  171. if (iter != cast_kernel_map.end()) {
  172. iter->second(args, dst, data_size);
  173. return true;
  174. } else {
  175. MS_LOG(ERROR) << "Unsupported datatype trans";
  176. return false;
  177. }
  178. }
  179. size_t CubeSizeByType(const TypeId data_type) {
  180. const size_t default_error = 0;
  181. auto dt_size = TypeIdSize(data_type);
  182. if (dt_size < 1) {
  183. MS_LOG(ERROR) << "Illegal dtype.";
  184. return default_error;
  185. } else if (dt_size == 1) {
  186. return kCubeSize * 2;
  187. }
  188. return kCubeSize;
  189. }
  // Returns the number of elements described by `shape`, i.e. the product of all dimensions. An empty shape
  // yields 1.
  size_t ShapeSize(const std::vector<size_t> &shape) {
    size_t element_count = 1;
    for (const size_t dim : shape) {
      element_count *= dim;
    }
    return element_count;
  }
  193. size_t TypeIdSize(const TypeId data_type) {
  194. const size_t unsupported_type_error = 0;
  195. auto iter = type_map.find(data_type);
  196. if (iter != type_map.end()) {
  197. return iter->second;
  198. }
  199. return unsupported_type_error;
  200. }
  201. namespace {
  202. bool CheckDims(const std::vector<size_t> &shape) {
  203. if (shape.size() != kNchwDims) {
  204. MS_LOG(ERROR) << "Host shape dims shoud be 4";
  205. return false;
  206. }
  207. return true;
  208. }
  209. std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
  210. if (!CheckDims(shape)) {
  211. MS_LOG(EXCEPTION) << "Check dims failed.";
  212. }
  213. return shape;
  214. }
  215. std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
  216. if (!CheckDims(shape)) {
  217. MS_LOG(EXCEPTION) << "Ccheck dims failed.";
  218. }
  219. std::vector<size_t> device_shape;
  220. device_shape.push_back(shape[kN]);
  221. device_shape.push_back(shape[kH]);
  222. device_shape.push_back(shape[kW]);
  223. device_shape.push_back(shape[kC]);
  224. return device_shape;
  225. }
  226. std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
  227. if (!CheckDims(shape)) {
  228. MS_LOG(EXCEPTION) << "Check dims failed.";
  229. }
  230. std::vector<size_t> device_shape;
  231. device_shape.push_back(shape[kH]);
  232. device_shape.push_back(shape[kW]);
  233. device_shape.push_back(shape[kC]);
  234. device_shape.push_back(shape[kN]);
  235. return device_shape;
  236. }
  237. std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
  238. if (!CheckDims(shape)) {
  239. MS_LOG(EXCEPTION) << "Check dims failed.";
  240. }
  241. std::vector<size_t> device_shape;
  242. const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  243. const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  244. device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
  245. device_shape.push_back(cout16 / kCubeSize);
  246. device_shape.push_back(kCubeSize);
  247. device_shape.push_back(kCubeSize);
  248. return device_shape;
  249. }
  250. std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  251. if (!CheckDims(shape)) {
  252. MS_LOG(EXCEPTION) << "Check dims failed.";
  253. }
  254. std::vector<size_t> device_shape;
  255. const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
  256. const size_t C0 = kCubeSize;
  257. device_shape.push_back(shape[kN]);
  258. device_shape.push_back(C1);
  259. device_shape.push_back(shape[kH]);
  260. device_shape.push_back(shape[kW]);
  261. device_shape.push_back(C0);
  262. return device_shape;
  263. }
  264. std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
  265. if (!CheckDims(shape)) {
  266. MS_LOG(EXCEPTION) << "Check dims failed.";
  267. }
  268. std::vector<size_t> device_shape;
  269. device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
  270. device_shape.push_back(shape[kH]);
  271. device_shape.push_back(shape[kW]);
  272. device_shape.push_back(shape[kN]);
  273. device_shape.push_back(kCubeSize);
  274. device_shape.push_back(kCubeSize);
  275. return device_shape;
  276. }
  277. std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
  278. if (!CheckDims(shape)) {
  279. MS_LOG(EXCEPTION) << "Check dims failed.";
  280. }
  281. std::vector<size_t> device_shape;
  282. const size_t c0 = 4;
  283. auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
  284. auto no = DivCeil(shape.at(kN), kCubeSize);
  285. device_shape.push_back(first_dim);
  286. device_shape.push_back(no);
  287. device_shape.push_back(kCubeSize);
  288. device_shape.push_back(kCubeSize);
  289. return device_shape;
  290. }
  291. std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
  292. if (!CheckDims(shape)) {
  293. MS_LOG(EXCEPTION) << "Check dims failed.";
  294. }
  295. std::vector<size_t> device_shape;
  296. const size_t C1 = 1;
  297. const size_t C0 = 4;
  298. device_shape.push_back(shape[kN]);
  299. device_shape.push_back(C1);
  300. device_shape.push_back(shape[kH]);
  301. device_shape.push_back(shape[kW]);
  302. device_shape.push_back(C0);
  303. return device_shape;
  304. }
  305. std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
  306. if (shape.size() < kNdhwc) {
  307. MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  308. }
  309. return shape;
  310. }
  311. std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
  312. std::vector<size_t> shape_4d(kNchwDims, 1);
  313. switch (shape.size()) {
  314. case 0:
  315. return shape_4d;
  316. case 1:
  317. shape_4d[kC] = shape[kN];
  318. break;
  319. case 2:
  320. shape_4d[kC] = shape[kN];
  321. shape_4d[kH] = shape[kC];
  322. break;
  323. case 3:
  324. shape_4d[kC] = shape[kN];
  325. shape_4d[kH] = shape[kC];
  326. shape_4d[kW] = shape[kH];
  327. break;
  328. case 4:
  329. std::copy(shape.begin(), shape.end(), shape_4d.begin());
  330. break;
  331. default:
  332. MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
  333. }
  334. return shape_4d;
  335. }
  336. } // namespace
  337. bool IsNeedPadding(const std::string &format, const size_t shape_size) {
  338. if (shape_size == 0) {
  339. return false;
  340. }
  341. if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
  342. return false;
  343. } else if (shape_size < kNchwDims) {
  344. return true;
  345. }
  346. return false;
  347. }
  348. std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  349. MS_EXCEPTION_IF_NULL(node);
  350. std::vector<int> shape;
  351. std::vector<size_t> host_shape;
  352. if (node->isa<ValueNode>()) {
  353. auto value_node = node->cast<ValueNodePtr>();
  354. MS_EXCEPTION_IF_NULL(value_node);
  355. auto node_value = value_node->value();
  356. MS_EXCEPTION_IF_NULL(node_value);
  357. auto tensor = node_value->cast<tensor::TensorPtr>();
  358. if (tensor == nullptr) {
  359. MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
  360. }
  361. auto shape_temp = tensor->shape();
  362. (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize);
  363. if (host_shape.empty()) {
  364. host_shape.push_back(1);
  365. }
  366. } else {
  367. host_shape = AnfAlgo::GetOutputInferShape(node, index);
  368. }
  369. if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
  370. host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
  371. }
  372. std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
  373. return shape;
  374. }
  375. std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
  376. if (padding_axis.empty() || shape.size() != padding_axis.size()) {
  377. return PaddingShapeTo4dByDefault(shape);
  378. }
  379. std::vector<size_t> shape_4d(kNchwDims, 1);
  380. for (size_t index = 0; index < padding_axis.size(); index++) {
  381. shape_4d[padding_axis[index]] = shape[index];
  382. }
  383. return shape_4d;
  384. }
  385. std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
  386. using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
  387. const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
  388. {kOpFormat_NHWC, NhwcDeviceShape},
  389. {kOpFormat_HWCN, HwchDeviceShape},
  390. {kOpFormat_FRAC_Z, FracZDeviceShape},
  391. {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
  392. {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
  393. {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
  394. {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
  395. {kOpFormat_NDHWC, NdhwcDeviceShape}};
  396. if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
  397. return shape;
  398. }
  399. auto temp_shape = shape;
  400. std::vector<size_t> device_shape;
  401. if (format == kOpFormat_FRAC_NZ) {
  402. if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
  403. // For [1] and [1024] shape we can trait it as NZ shape
  404. return shape;
  405. }
  406. if (shape.size() < 2) {
  407. MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
  408. } else {
  409. (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
  410. }
  411. auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
  412. auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1;
  413. device_shape.push_back(w1);
  414. device_shape.push_back(h1);
  415. device_shape.push_back(kCubeSize);
  416. device_shape.push_back(kCubeSize);
  417. return device_shape;
  418. }
  419. if (shape.size() != kNchwDims) {
  420. MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
  421. temp_shape = PaddingShapeTo4dByDefault(shape);
  422. }
  423. auto iter = device_shape_map.find(format);
  424. if (iter == device_shape_map.end()) {
  425. MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  426. }
  427. return iter->second(temp_shape);
  428. }
  429. bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
  430. if (args.host_shape.size() != kNchwDims) {
  431. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  432. return false;
  433. }
  434. MS_EXCEPTION_IF_NULL(size);
  435. MS_EXCEPTION_IF_NULL(total_size);
  436. *size = TypeIdSize(args.src_data_type);
  437. if (*size < 1) {
  438. MS_LOG(ERROR) << "Illegal dtype.";
  439. return false;
  440. }
  441. *total_size = ShapeSize(args.device_shape) * (*size);
  442. if (*total_size != args.device_size) {
  443. MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
  444. return false;
  445. }
  446. return true;
  447. }
  448. bool TransDataType(const TypeIdArgs &args, void *result) {
  449. MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
  450. << TypeIdLabel(args.device_data_type);
  451. MS_EXCEPTION_IF_NULL(result);
  452. std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
  453. auto iter = mode_map.find(type_info);
  454. if (iter == mode_map.end()) {
  455. MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
  456. << ", dst_type:" << TypeIdLabel(args.device_data_type);
  457. return false;
  458. }
  459. auto trans_mode = iter->second;
  460. if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
  461. MS_LOG(ERROR) << "Failed to trans datatype..";
  462. return false;
  463. }
  464. return true;
  465. }
  466. bool TransFormat(const FormatArgs &args, void *result) {
  467. using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
  468. const std::map<std::string, FormatTransfer> format_trans_map{
  469. {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
  470. {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
  471. {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
  472. MS_LOG(DEBUG) << "Start trans format.";
  473. if (TypeIdSize(args.src_data_type) < 1) {
  474. MS_LOG(ERROR) << "Invalid datatype..";
  475. return false;
  476. }
  477. if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
  478. return NchwTo4D(args, result);
  479. }
  480. auto iter = format_trans_map.find(args.device_format);
  481. if (iter == format_trans_map.end()) {
  482. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  483. }
  484. return iter->second(args, result);
  485. }
  486. bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
  487. using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
  488. const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw},
  489. {kOpFormat_FRAC_NZ, FracNzToNchw},
  490. {kOpFormat_NC1HWC0, Nc1hwc0ToNchw},
  491. {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
  492. {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
  493. MS_LOG(DEBUG) << "Start trans format.";
  494. if (TypeIdSize(args.src_data_type) < 1) {
  495. MS_LOG(ERROR) << "Invalid datatype..";
  496. return false;
  497. }
  498. if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
  499. return ToNchw(args, result);
  500. }
  501. auto iter = format_trans_map.find(args.device_format);
  502. if (iter == format_trans_map.end()) {
  503. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  504. }
  505. return iter->second(args, result);
  506. }
  507. bool NchwTo4D(const FormatArgs &args, void *result) {
  508. // trans nchw to 4d
  509. MS_LOG(DEBUG) << "Trans format from nchw to 4d.";
  510. MS_EXCEPTION_IF_NULL(result);
  511. size_t size = 0;
  512. size_t total_size = 0;
  513. if (!CheckArgs(args, &size, &total_size)) {
  514. MS_LOG(ERROR) << "Check args failed.";
  515. return false;
  516. }
  517. auto n = args.host_shape[kN];
  518. auto c = args.host_shape[kC];
  519. auto h = args.host_shape[kH];
  520. auto w = args.host_shape[kW];
  521. for (size_t ni = 0; ni < n; ni++) {
  522. for (size_t ci = 0; ci < c; ci++) {
  523. for (size_t hi = 0; hi < h; hi++) {
  524. for (size_t wi = 0; wi < w; wi++) {
  525. auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
  526. auto dst_idx = 0;
  527. if (args.device_format == kOpFormat_NHWC) {
  528. dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
  529. } else if (args.device_format == kOpFormat_HWCN) {
  530. dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
  531. }
  532. SetData(size, false, src_idx, dst_idx, args, result);
  533. }
  534. }
  535. }
  536. }
  537. return true;
  538. }
  539. bool ToNchw(const FormatArgs &args, void *result) {
  540. MS_LOG(DEBUG) << "Trans format to nchw from 4d.";
  541. MS_EXCEPTION_IF_NULL(result);
  542. size_t size = 0;
  543. size_t total_size = 0;
  544. if (!CheckArgs(args, &size, &total_size)) {
  545. MS_LOG(ERROR) << "Check args failed.";
  546. return false;
  547. }
  548. auto n = args.host_shape[kN];
  549. auto c = args.host_shape[kC];
  550. auto h = args.host_shape[kH];
  551. auto w = args.host_shape[kW];
  552. for (size_t ni = 0; ni < n; ni++) {
  553. for (size_t ci = 0; ci < c; ci++) {
  554. for (size_t hi = 0; hi < h; hi++) {
  555. for (size_t wi = 0; wi < w; wi++) {
  556. auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
  557. auto src_idx = 0;
  558. if (args.device_format == kOpFormat_NHWC) {
  559. src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
  560. } else if (args.device_format == kOpFormat_HWCN) {
  561. src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
  562. }
  563. SetData(size, false, src_idx, dst_idx, args, result);
  564. }
  565. }
  566. }
  567. }
  568. return true;
  569. }
  570. bool NchwToFracZ(const FormatArgs &args, void *result) {
  571. MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  572. MS_EXCEPTION_IF_NULL(result);
  573. if (args.host_shape.size() != kNchwDims) {
  574. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  575. return false;
  576. }
  577. auto size = TypeIdSize(args.src_data_type);
  578. if (size < 1) {
  579. MS_LOG(ERROR) << "Illegal dtype.";
  580. return false;
  581. }
  582. auto n = args.host_shape[kN];
  583. auto c = args.host_shape[kC];
  584. auto h = args.host_shape[kH];
  585. auto w = args.host_shape[kW];
  586. auto c0 = CubeSizeByType(args.src_data_type);
  587. if (c0 < 1) {
  588. MS_LOG(ERROR) << "Illegal dtype.";
  589. return false;
  590. }
  591. auto c1 = DivCeil(c, c0);
  592. auto hw = h * w;
  593. auto chw = c * hw;
  594. auto hwc0 = hw * c0;
  595. auto nchw = n * chw;
  596. auto hf_cnt = DivCeil(n, kCubeSize);
  597. auto vf_cnt = c1 * hw;
  598. auto fractal_ele_cnt = c0 * kCubeSize;
  599. auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
  600. auto dst_size = total_ele_cnt * size;
  601. if (dst_size != args.device_size) {
  602. MS_LOG(ERROR) << "Illegal total data size."
  603. << "dst size is :" << dst_size << "device size is :" << args.device_size;
  604. return false;
  605. }
  606. for (size_t vfi = 0; vfi < vf_cnt; vfi++) {
  607. auto vf_base_i = vfi * hf_cnt; // vertical fractal matrix base index
  608. for (size_t hfi = 0; hfi < hf_cnt; hfi++) {
  609. auto gfi = vf_base_i + hfi; // global fractal matrix index
  610. auto src_n_offset = hfi * chw * kCubeSize;
  611. auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
  612. for (size_t row = 0; row < c0; row++) {
  613. auto src_ci = vfi / hw * c0 + row;
  614. auto src_row_offset = src_f_offset + row * hw;
  615. for (size_t col = 0; col < kCubeSize; col++) {
  616. auto src_ni = hfi * kCubeSize + col;
  617. auto src_idx = src_row_offset + chw * col;
  618. auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
  619. auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
  620. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  621. }
  622. }
  623. }
  624. }
  625. return true;
  626. }
  627. bool FracZToNchw(const FormatArgs &args, void *result) {
  628. MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
  629. MS_EXCEPTION_IF_NULL(result);
  630. if (args.host_shape.size() != kNchwDims) {
  631. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  632. return false;
  633. }
  634. auto size = TypeIdSize(args.src_data_type);
  635. if (size < 1) {
  636. MS_LOG(ERROR) << "Illegal dtype.";
  637. return false;
  638. }
  639. auto total_size = ShapeSize(args.device_shape) * size;
  640. if (total_size != args.device_size) {
  641. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  642. return false;
  643. }
  644. auto n0 = args.device_shape.at(1);
  645. auto ni = args.device_shape.at(2);
  646. auto c0 = args.device_shape.at(3);
  647. auto n = args.host_shape[kN];
  648. auto c = args.host_shape[kC];
  649. auto h = args.host_shape[kH];
  650. auto w = args.host_shape[kW];
  651. auto nc = ni * n0;
  652. auto ncc0 = nc * c0;
  653. auto wncc0 = w * ncc0;
  654. auto hwncc0 = h * wncc0;
  655. auto hw = h * w;
  656. auto chw = c * hw;
  657. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  658. size_t n_head_addr = n_idx * chw;
  659. for (size_t c_idx = 0; c_idx < c; c_idx++) {
  660. size_t c_head_addr = n_head_addr + c_idx * hw;
  661. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  662. size_t h_head_addr = c_head_addr + h_idx * w;
  663. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  664. size_t dst_idx = h_head_addr + w_idx;
  665. size_t c1_idx = c_idx / c0;
  666. size_t c0_idx = c_idx % c0;
  667. size_t nc_idx = n_idx;
  668. size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
  669. SetData(size, false, src_idx, dst_idx, args, result);
  670. }
  671. }
  672. }
  673. }
  674. return true;
  675. }
  676. bool NchwToFracZc04(const FormatArgs &args, void *result) {
  677. // trans nchw to FracZc04
  678. MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
  679. MS_EXCEPTION_IF_NULL(result);
  680. size_t size = 0;
  681. size_t total_size = 0;
  682. if (!CheckArgs(args, &size, &total_size)) {
  683. MS_LOG(ERROR) << "Check args failed.";
  684. return false;
  685. }
  686. auto cube = kCubeSize;
  687. auto n = args.host_shape[kN];
  688. auto c = args.host_shape[kC];
  689. auto h = args.host_shape[kH];
  690. auto w = args.host_shape[kW];
  691. const size_t c0 = 4;
  692. auto c1 = DivCeil(c, c0);
  693. auto hwc0 = h * w * c0;
  694. auto hwc = h * w * c;
  695. auto nhwc = n * h * w * c;
  696. auto n_cnt = DivCeil(n, cube);
  697. auto v_cnt = DivCeil(h * w * c0 * c1, cube);
  698. size_t dst_idx = 0;
  699. for (size_t vi = 0; vi < v_cnt; vi++) {
  700. for (size_t ni = 0; ni < n_cnt; ni++) {
  701. for (size_t col = 0; col < cube; col++) {
  702. for (size_t row = 0; row < cube; row++) {
  703. size_t cur_cube_n = cube * ni + col;
  704. size_t cur_cube_c1hwc0 = cube * vi + row;
  705. auto desc_g = cur_cube_n / n;
  706. auto desc_n = cur_cube_n % n;
  707. auto desc_c1 = cur_cube_c1hwc0 / hwc0;
  708. auto desc_c0 = cur_cube_c1hwc0 % c0;
  709. auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
  710. auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
  711. auto c_idx = desc_c1 * c0 + desc_c0;
  712. auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
  713. auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
  714. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  715. dst_idx++;
  716. }
  717. }
  718. }
  719. }
  720. return true;
  721. }
  722. bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
  723. MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
  724. return NchwToNc1hwc0(args, result);
  725. }
  726. bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
  727. MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
  728. return Nc1hwc0ToNchw(args, result);
  729. }
  730. bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
  731. MS_EXCEPTION_IF_NULL(hw_shape);
  732. if (host_shape.empty()) {
  733. MS_LOG(ERROR) << "Size of vector is 0.";
  734. return false;
  735. }
  736. switch (host_shape.size()) {
  737. case 1:
  738. hw_shape->push_back(1);
  739. hw_shape->push_back(1);
  740. hw_shape->push_back(host_shape[0]);
  741. return true;
  742. default:
  743. auto size = host_shape.size();
  744. if (size < 2) {
  745. MS_LOG(ERROR) << "Illegal size.";
  746. return false;
  747. }
  748. size_t times = 1;
  749. for (size_t i = 0; i != size - 2; i++) {
  750. times *= host_shape[i];
  751. }
  752. hw_shape->push_back(times);
  753. hw_shape->push_back(host_shape[size - 2]);
  754. hw_shape->push_back(host_shape[size - 1]);
  755. return true;
  756. }
  757. }
  758. bool NchwToFracNz(const FormatArgs &args, void *result) {
  759. MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  760. MS_EXCEPTION_IF_NULL(result);
  761. std::vector<size_t> hw_shape;
  762. if (!TransShapeToNz(args.host_shape, &hw_shape)) {
  763. MS_LOG(ERROR) << "Trans shape failed..";
  764. return false;
  765. }
  766. if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
  767. MS_LOG(ERROR) << "Invalid shape size.";
  768. return false;
  769. }
  770. auto size = TypeIdSize(args.src_data_type);
  771. if (size < 1) {
  772. MS_LOG(ERROR) << "Illegal dtype";
  773. return false;
  774. }
  775. auto dst_size = ShapeSize(args.device_shape) * size;
  776. if (dst_size != args.device_size) {
  777. MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
  778. return false;
  779. }
  780. auto times = hw_shape.at(0);
  781. auto h = hw_shape.at(1);
  782. auto w = hw_shape.at(2);
  783. auto hw = h * w;
  784. auto shape_size = args.device_shape.size();
  785. auto w1 = args.device_shape[shape_size - 4];
  786. auto h1 = args.device_shape[shape_size - 3];
  787. auto h0 = args.device_shape[shape_size - 2];
  788. auto w0 = args.device_shape[shape_size - 1];
  789. auto h1h0w0 = h1 * h0 * w0;
  790. auto w1h1h0w0 = w1 * h1h0w0;
  791. auto num_w1 = w / w0;
  792. for (size_t times_idx = 0; times_idx < times; times_idx++) {
  793. auto times_head = times_idx * w1h1h0w0;
  794. auto src_times_head = times_idx * hw;
  795. for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
  796. auto h1h0_head = times_head + h1h0_idx * w0;
  797. auto src_h_head = src_times_head + h1h0_idx * w;
  798. for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
  799. for (size_t i = 0; i < w0; ++i) {
  800. size_t src_idx = src_h_head + w1_idx * w0 + i;
  801. size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
  802. SetData(size, false, src_idx, dst_idx, args, result);
  803. }
  804. }
  805. auto w1_head = num_w1 * w0;
  806. for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
  807. auto src_w_idx = w1_head + w0_idx;
  808. size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
  809. size_t src_idx = src_h_head + src_w_idx;
  810. SetData(size, false, src_idx, dst_idx, args, result);
  811. }
  812. }
  813. }
  814. return true;
  815. }
  816. bool FracNzToNchw(const FormatArgs &args, void *result) {
  817. MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
  818. MS_EXCEPTION_IF_NULL(result);
  819. std::vector<size_t> hw_shape;
  820. if (!TransShapeToNz(args.host_shape, &hw_shape)) {
  821. MS_LOG(ERROR) << "Trans shape failed..";
  822. return false;
  823. }
  824. if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
  825. MS_LOG(ERROR) << "Invalid shape size.";
  826. return false;
  827. }
  828. auto size = TypeIdSize(args.src_data_type);
  829. if (size < 1) {
  830. MS_LOG(ERROR) << "Illegal dtype";
  831. return false;
  832. }
  833. auto dst_size = ShapeSize(args.device_shape) * size;
  834. if (dst_size != args.device_size) {
  835. MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
  836. return false;
  837. }
  838. auto times = hw_shape.at(0);
  839. auto h = hw_shape.at(1);
  840. auto w = hw_shape.at(2);
  841. auto hw = h * w;
  842. auto shape_size = args.device_shape.size();
  843. auto w1 = args.device_shape[shape_size - 4];
  844. auto h1 = args.device_shape[shape_size - 3];
  845. auto h0 = args.device_shape[shape_size - 2];
  846. auto w0 = args.device_shape[shape_size - 1];
  847. auto h1h0w0 = h1 * h0 * w0;
  848. auto w1h1h0w0 = w1 * h1h0w0;
  849. auto num_w1 = w / w0;
  850. for (size_t times_idx = 0; times_idx < times; times_idx++) {
  851. auto times_head = times_idx * w1h1h0w0;
  852. auto src_times_head = times_idx * hw;
  853. for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
  854. auto h1h0_head = times_head + h1h0_idx * w0;
  855. auto src_h_head = src_times_head + h1h0_idx * w;
  856. for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
  857. for (size_t i = 0; i < w0; ++i) {
  858. size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
  859. size_t dst_idx = src_h_head + w1_idx * w0 + i;
  860. SetData(size, false, src_idx, dst_idx, args, result);
  861. }
  862. }
  863. auto w1_head = num_w1 * w0;
  864. for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
  865. auto src_w_idx = w1_head + w0_idx;
  866. size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
  867. size_t dst_idx = src_h_head + src_w_idx;
  868. SetData(size, false, src_idx, dst_idx, args, result);
  869. }
  870. }
  871. }
  872. return true;
  873. }
  874. bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
  875. MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
  876. MS_EXCEPTION_IF_NULL(result);
  877. if (args.host_shape.size() != kNchwDims) {
  878. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  879. return false;
  880. }
  881. auto size = TypeIdSize(args.src_data_type);
  882. if (size < 1) {
  883. MS_LOG(ERROR) << "Illegal dtype.";
  884. return false;
  885. }
  886. auto total_size = ShapeSize(args.device_shape) * size;
  887. if (total_size != args.device_size) {
  888. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  889. return false;
  890. }
  891. auto n = args.host_shape[kN];
  892. auto c = args.host_shape[kC];
  893. auto h = args.host_shape[kH];
  894. auto w = args.host_shape[kW];
  895. auto c0 = CubeSizeByType(args.src_data_type);
  896. if (c0 < 1) {
  897. MS_LOG(ERROR) << "Illegal dtype.";
  898. return false;
  899. }
  900. if (args.device_format == kOpFormat_NC1HWC0_C04) {
  901. c0 = 4;
  902. }
  903. auto c1 = DivCeil(c, c0);
  904. auto hw = h * w;
  905. auto chw = c * hw;
  906. auto c1hwc0 = c1 * hw * c0;
  907. auto wc0 = w * c0;
  908. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  909. size_t n_head_addr = n_idx * c1hwc0;
  910. for (size_t c1_idx = 0; c1_idx < c1; c1_idx++) {
  911. size_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
  912. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  913. size_t h_head_addr = c1_head_addr + h_idx * wc0;
  914. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  915. size_t w_head_addr = h_head_addr + w_idx * c0;
  916. for (size_t c0_idx = 0; c0_idx < c0; c0_idx++) {
  917. size_t dst_idx = c0_idx + w_head_addr;
  918. size_t c_idx = c0_idx + c1_idx * c0;
  919. size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
  920. auto pad_zero = c_idx >= c;
  921. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  922. }
  923. }
  924. }
  925. }
  926. }
  927. return true;
  928. }
  929. bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
  930. MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
  931. MS_EXCEPTION_IF_NULL(result);
  932. if (args.host_shape.size() != kNchwDims) {
  933. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  934. return false;
  935. }
  936. auto size = TypeIdSize(args.src_data_type);
  937. if (size < 1) {
  938. MS_LOG(ERROR) << "Illegal dtype.";
  939. return false;
  940. }
  941. auto total_size = ShapeSize(args.device_shape) * size;
  942. if (total_size != args.device_size) {
  943. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  944. return false;
  945. }
  946. auto n = args.host_shape[kN];
  947. auto c = args.host_shape[kC];
  948. auto h = args.host_shape[kH];
  949. auto w = args.host_shape[kW];
  950. auto c1 = args.device_shape[1];
  951. auto c0 = args.device_shape[4];
  952. auto hw = h * w;
  953. auto chw = c * hw;
  954. auto wc0 = w * c0;
  955. auto hwc0 = h * wc0;
  956. auto c1hwc0 = c1 * hwc0;
  957. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  958. size_t n_head_addr = n_idx * chw;
  959. for (size_t c_idx = 0; c_idx < c; c_idx++) {
  960. size_t c_head_addr = n_head_addr + c_idx * hw;
  961. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  962. size_t h_head_addr = c_head_addr + h_idx * w;
  963. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  964. size_t dst_idx = h_head_addr + w_idx;
  965. size_t c1_idx = c_idx / c0;
  966. size_t c0_idx = c_idx % c0;
  967. size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
  968. SetData(size, false, src_idx, dst_idx, args, result);
  969. }
  970. }
  971. }
  972. }
  973. return true;
  974. }
  975. bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
  976. // trans nchw to c1hwncoc0
  977. MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
  978. MS_EXCEPTION_IF_NULL(result);
  979. size_t size = 0;
  980. size_t total_size = 0;
  981. if (!CheckArgs(args, &size, &total_size)) {
  982. MS_LOG(ERROR) << "Check args failed.";
  983. return false;
  984. }
  985. auto n = args.host_shape[kN];
  986. auto c = args.host_shape[kC];
  987. auto h = args.host_shape[kH];
  988. auto w = args.host_shape[kW];
  989. const int co_idx = 4;
  990. const int c0_idx = 5;
  991. auto c1 = args.device_shape[0];
  992. auto co = args.device_shape[co_idx];
  993. auto c0 = args.device_shape[c0_idx];
  994. for (size_t c1_i = 0; c1_i < c1; c1_i++) {
  995. for (size_t h_i = 0; h_i < h; h_i++) {
  996. for (size_t w_i = 0; w_i < w; w_i++) {
  997. for (size_t n_i = 0; n_i < n; n_i++) {
  998. for (size_t co_i = 0; co_i < co; co_i++) {
  999. for (size_t c0_i = 0; c0_i < c0; c0_i++) {
  1000. size_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
  1001. co_i * c0 + c0_i;
  1002. size_t c_i = c0_i + c1_i * c0;
  1003. size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
  1004. auto pad_zero = !(c_i < c && c0_i == co_i);
  1005. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  1006. }
  1007. }
  1008. }
  1009. }
  1010. }
  1011. }
  1012. return true;
  1013. }
  1014. bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
  1015. // trans c1hwncoc0 to nchw
  1016. MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
  1017. MS_EXCEPTION_IF_NULL(result);
  1018. size_t size = 0;
  1019. size_t total_size = 0;
  1020. if (!CheckArgs(args, &size, &total_size)) {
  1021. MS_LOG(ERROR) << "Check args failed.";
  1022. return false;
  1023. }
  1024. auto n = args.host_shape[kN];
  1025. auto c = args.host_shape[kC];
  1026. auto h = args.host_shape[kH];
  1027. auto w = args.host_shape[kW];
  1028. const int co_idx = 4;
  1029. const int c0_idx = 5;
  1030. auto co = args.device_shape[co_idx];
  1031. auto c0 = args.device_shape[c0_idx];
  1032. for (size_t n_i = 0; n_i < n; n_i++) {
  1033. for (size_t c_i = 0; c_i < c; c_i++) {
  1034. for (size_t h_i = 0; h_i < h; h_i++) {
  1035. for (size_t w_i = 0; w_i < w; w_i++) {
  1036. size_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
  1037. size_t c1_i = c_i / kCubeSize;
  1038. size_t c0_i = c_i % kCubeSize;
  1039. size_t co_i = c0_i;
  1040. size_t src_idx =
  1041. c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
  1042. SetData(size, false, src_idx, dst_idx, args, result);
  1043. }
  1044. }
  1045. }
  1046. }
  1047. return true;
  1048. }
  1049. } // namespace trans
  1050. } // namespace mindspore