You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

trans.cc 41 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/trans.h"
  17. #include <functional>
  18. #include <numeric>
  19. #include <utility>
  20. #include "utils/ms_utils.h"
  21. #include "abstract/utils.h"
  22. #include "backend/session/anf_runtime_algorithm.h"
  23. #include "backend/kernel_compiler/kernel.h"
  24. #include "runtime/device/convert_tensor_utils.h"
  25. #include "utils/convert_utils.h"
  26. #include "utils/log_adapter.h"
  27. #include "utils/utils.h"
  28. namespace mindspore {
  29. namespace trans {
  // Positions of the N/C/H/W axes inside a 4-D NCHW shape, followed by two rank
  // constants: kNchwDims (4) is the NCHW rank and kNdhwc (5) the NDHWC rank.
  enum kAxis : int {
    kN = 0,
    kC = 1,
    kH = 2,
    kW = 3,
    kNchwDims = 4,
    kNdhwc = 5
  };
  31. inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
  32. switch (size) {
  33. case 1:
  34. static_cast<uint8_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint8_t *>(args.data)[src_idx];
  35. break;
  36. case 2:
  37. static_cast<uint16_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint16_t *>(args.data)[src_idx];
  38. break;
  39. case 4:
  40. static_cast<uint32_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint32_t *>(args.data)[src_idx];
  41. break;
  42. case 8:
  43. static_cast<uint64_t *>(result)[dst_idx] = pad_zero ? 0 : static_cast<const uint64_t *>(args.data)[src_idx];
  44. break;
  45. default:
  46. MS_LOG(EXCEPTION) << "Trans data not support size " << size;
  47. }
  48. }
  // Returns ceil(n1 / n2) for integral T, or 0 when n2 is 0.
  //
  // Computed from the truncated quotient and remainder instead of the usual
  // (n1 + n2 - 1) / n2, which overflows (and silently wraps for unsigned T)
  // when n1 is close to the maximum value of T.
  template <typename T>
  T DivCeil(T n1, T n2) {
    if (n2 == 0) {
      return 0;
    }
    T quotient = n1 / n2;
    const T remainder = n1 % n2;
    // Truncation already rounds up when the exact quotient is negative; round up
    // explicitly only when there is a remainder and the exact quotient is positive.
    if (remainder != 0 && ((remainder > 0) == (n2 > 0))) {
      ++quotient;
    }
    return quotient;
  }
  // Every dtype conversion CastKernel can perform, named FROM_<src>_TO_<dst>.
  // Enumerators are listed in their original order, so the implicit values
  // (0 for FROM_FLOAT_TO_FLOAT16 up to 24 for FROM_FLOAT32_TO_FLOAT64) are unchanged.
  enum DataTypeTransMode {
    FROM_FLOAT_TO_FLOAT16, FROM_FLOAT_TO_INT32,                                        // float32 sources
    FROM_FLOAT16_TO_FLOAT, FROM_FLOAT16_TO_INT32, FROM_FLOAT16_TO_UINT8,               // float16 sources
    FROM_INT32_TO_FLOAT, FROM_INT32_TO_FLOAT16, FROM_INT32_TO_UINT8,                   // int32 sources
    FROM_INT32_TO_INT8, FROM_INT32_TO_INT64, FROM_INT32_TO_BOOL,                       // int32 sources (cont.)
    FROM_UINT8_TO_FLOAT, FROM_UINT8_TO_INT32, FROM_UINT8_TO_FLOAT16,                   // uint8 sources
    FROM_INT8_TO_FLOAT, FROM_INT8_TO_FLOAT16, FROM_INT8_TO_INT32,                      // int8 sources
    FROM_INT64_TO_INT32,                                                               // int64 source
    FROM_UINT16_TO_INT32,                                                              // uint16 source
    FROM_BOOL_TO_FLOAT, FROM_BOOL_TO_INT32, FROM_BOOL_TO_UINT8, FROM_BOOL_TO_FLOAT16,  // bool sources
    FROM_FLOAT64_TO_FLOAT32, FROM_FLOAT32_TO_FLOAT64                                   // float32 <-> float64
  };
  83. const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
  84. {std::pair<TypeId, TypeId>(kNumberTypeFloat64, kNumberTypeFloat32), FROM_FLOAT64_TO_FLOAT32},
  85. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat64), FROM_FLOAT32_TO_FLOAT64},
  86. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), FROM_FLOAT_TO_FLOAT16},
  87. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), FROM_FLOAT_TO_INT32},
  88. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), FROM_FLOAT16_TO_FLOAT},
  89. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), FROM_FLOAT16_TO_INT32},
  90. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeUInt8), FROM_FLOAT16_TO_UINT8},
  91. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), FROM_INT32_TO_FLOAT},
  92. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), FROM_INT32_TO_FLOAT16},
  93. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), FROM_INT32_TO_UINT8},
  94. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), FROM_INT32_TO_INT8},
  95. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt64), FROM_INT32_TO_INT64},
  96. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeBool), FROM_INT32_TO_BOOL},
  97. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), FROM_UINT8_TO_FLOAT},
  98. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32},
  99. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat16), FROM_UINT8_TO_FLOAT16},
  100. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT},
  101. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat16), FROM_INT8_TO_FLOAT16},
  102. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32},
  103. {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32},
  104. {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32},
  105. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeInt32), FROM_BOOL_TO_INT32},
  106. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat), FROM_BOOL_TO_FLOAT},
  107. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeUInt8), FROM_BOOL_TO_UINT8},
  108. {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};
  109. void CheckMemSize(const TypeIdArgs &args) {
  110. auto src_type_size = abstract::TypeIdSize(args.host_data_type);
  111. auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
  112. if (src_type_size < 1 || dst_type_size < 1) {
  113. MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
  114. }
  115. if (args.data_size / src_type_size != args.host_shape_size) {
  116. MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
  117. }
  118. }
  119. template <typename SrcT, typename DstT>
  120. void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
  121. CheckMemSize(args);
  122. for (size_t idx = 0; idx != data_size; idx++) {
  123. SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
  124. static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
  125. }
  126. }
  127. template <typename SrcT>
  128. void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
  129. CheckMemSize(args);
  130. auto src_data = static_cast<const SrcT *>(args.data);
  131. auto half_data = static_cast<float16 *>(dst);
  132. for (size_t i = 0; i < data_size; i++) {
  133. half_data[i] = float16(src_data[i]);
  134. }
  135. }
  136. bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
  137. using DtypeKernel = std::function<void(const TypeIdArgs &, void *, const size_t)>;
  138. const std::map<DataTypeTransMode, DtypeKernel> cast_kernel_map{
  139. {FROM_FLOAT_TO_INT32, TransDataSrc2Dst<float, int32_t>},
  140. {FROM_FLOAT64_TO_FLOAT32, TransDataSrc2Dst<double, float>},
  141. {FROM_FLOAT32_TO_FLOAT64, TransDataSrc2Dst<float, double>},
  142. {FROM_FLOAT16_TO_INT32, TransDataSrc2Dst<float16, int32_t>},
  143. {FROM_FLOAT16_TO_UINT8, TransDataSrc2Dst<float16, uint8_t>},
  144. {FROM_INT32_TO_FLOAT, TransDataSrc2Dst<int32_t, float>},
  145. {FROM_INT32_TO_INT8, TransDataSrc2Dst<int32_t, int8_t>},
  146. {FROM_INT32_TO_INT64, TransDataSrc2Dst<int32_t, int64_t>},
  147. {FROM_INT32_TO_UINT8, TransDataSrc2Dst<int32_t, uint8_t>},
  148. {FROM_INT32_TO_BOOL, TransDataSrc2Dst<int32_t, int8_t>},
  149. {FROM_INT32_TO_FLOAT16, TransDataSrc2Fp16<int32_t>},
  150. {FROM_UINT8_TO_FLOAT, TransDataSrc2Dst<uint8_t, float>},
  151. {FROM_UINT8_TO_INT32, TransDataSrc2Dst<uint8_t, int32_t>},
  152. {FROM_UINT8_TO_FLOAT16, TransDataSrc2Fp16<uint8_t>},
  153. {FROM_INT8_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
  154. {FROM_INT8_TO_FLOAT16, TransDataSrc2Fp16<int8_t>},
  155. {FROM_INT8_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
  156. {FROM_INT64_TO_INT32, TransDataSrc2Dst<int64_t, int32_t>},
  157. {FROM_UINT16_TO_INT32, TransDataSrc2Dst<uint16_t, int32_t>},
  158. {FROM_BOOL_TO_INT32, TransDataSrc2Dst<int8_t, int32_t>},
  159. {FROM_BOOL_TO_FLOAT, TransDataSrc2Dst<int8_t, float>},
  160. {FROM_BOOL_TO_UINT8, TransDataSrc2Dst<int8_t, uint8_t>},
  161. {FROM_BOOL_TO_FLOAT16, TransDataSrc2Fp16<int8_t>}};
  162. if (mode == FROM_FLOAT_TO_FLOAT16) {
  163. device::FloatToHalf(dst, args.data, data_size);
  164. return true;
  165. } else if (mode == FROM_FLOAT16_TO_FLOAT) {
  166. device::HalfToFloat(dst, args.data, data_size);
  167. return true;
  168. }
  169. auto iter = cast_kernel_map.find(mode);
  170. if (iter != cast_kernel_map.end()) {
  171. iter->second(args, dst, data_size);
  172. return true;
  173. } else {
  174. MS_LOG(ERROR) << "Unsupported datatype trans";
  175. return false;
  176. }
  177. }
  178. size_t CubeSizeByType(const TypeId data_type) {
  179. const size_t default_error = 0;
  180. auto dt_size = abstract::TypeIdSize(data_type);
  181. if (dt_size < 1) {
  182. MS_LOG(ERROR) << "Illegal dtype.";
  183. return default_error;
  184. } else if (dt_size == 1) {
  185. return kCubeSize * 2;
  186. }
  187. return kCubeSize;
  188. }
  189. namespace {
  190. bool CheckDims(const std::vector<size_t> &shape) {
  191. if (shape.size() != kNchwDims) {
  192. MS_LOG(ERROR) << "Host shape dims shoud be 4";
  193. return false;
  194. }
  195. return true;
  196. }
  197. std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
  198. if (!CheckDims(shape)) {
  199. MS_LOG(EXCEPTION) << "Check dims failed.";
  200. }
  201. return shape;
  202. }
  203. std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
  204. if (!CheckDims(shape)) {
  205. MS_LOG(EXCEPTION) << "Ccheck dims failed.";
  206. }
  207. std::vector<size_t> device_shape;
  208. device_shape.push_back(shape[kN]);
  209. device_shape.push_back(shape[kH]);
  210. device_shape.push_back(shape[kW]);
  211. device_shape.push_back(shape[kC]);
  212. return device_shape;
  213. }
  214. std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
  215. if (!CheckDims(shape)) {
  216. MS_LOG(EXCEPTION) << "Check dims failed.";
  217. }
  218. std::vector<size_t> device_shape;
  219. device_shape.push_back(shape[kH]);
  220. device_shape.push_back(shape[kW]);
  221. device_shape.push_back(shape[kC]);
  222. device_shape.push_back(shape[kN]);
  223. return device_shape;
  224. }
  225. std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
  226. if (!CheckDims(shape)) {
  227. MS_LOG(EXCEPTION) << "Check dims failed.";
  228. }
  229. std::vector<size_t> device_shape;
  230. const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  231. const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  232. device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
  233. device_shape.push_back(cout16 / kCubeSize);
  234. device_shape.push_back(kCubeSize);
  235. device_shape.push_back(kCubeSize);
  236. return device_shape;
  237. }
  238. std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  239. if (!CheckDims(shape)) {
  240. MS_LOG(EXCEPTION) << "Check dims failed.";
  241. }
  242. std::vector<size_t> device_shape;
  243. const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
  244. const size_t C0 = kCubeSize;
  245. device_shape.push_back(shape[kN]);
  246. device_shape.push_back(C1);
  247. device_shape.push_back(shape[kH]);
  248. device_shape.push_back(shape[kW]);
  249. device_shape.push_back(C0);
  250. return device_shape;
  251. }
  252. std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
  253. if (!CheckDims(shape)) {
  254. MS_LOG(EXCEPTION) << "Check dims failed.";
  255. }
  256. std::vector<size_t> device_shape;
  257. device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
  258. device_shape.push_back(shape[kH]);
  259. device_shape.push_back(shape[kW]);
  260. device_shape.push_back(shape[kN]);
  261. device_shape.push_back(kCubeSize);
  262. device_shape.push_back(kCubeSize);
  263. return device_shape;
  264. }
  265. std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
  266. if (!CheckDims(shape)) {
  267. MS_LOG(EXCEPTION) << "Check dims failed.";
  268. }
  269. std::vector<size_t> device_shape;
  270. const size_t c0 = 4;
  271. auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
  272. auto no = DivCeil(shape.at(kN), kCubeSize);
  273. device_shape.push_back(first_dim);
  274. device_shape.push_back(no);
  275. device_shape.push_back(kCubeSize);
  276. device_shape.push_back(kCubeSize);
  277. return device_shape;
  278. }
  279. std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
  280. if (!CheckDims(shape)) {
  281. MS_LOG(EXCEPTION) << "Check dims failed.";
  282. }
  283. std::vector<size_t> device_shape;
  284. const size_t C1 = 1;
  285. const size_t C0 = 4;
  286. device_shape.push_back(shape[kN]);
  287. device_shape.push_back(C1);
  288. device_shape.push_back(shape[kH]);
  289. device_shape.push_back(shape[kW]);
  290. device_shape.push_back(C0);
  291. return device_shape;
  292. }
  293. std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
  294. if (shape.size() < kNdhwc) {
  295. MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
  296. }
  297. return shape;
  298. }
  299. std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
  300. std::vector<size_t> shape_4d(kNchwDims, 1);
  301. switch (shape.size()) {
  302. case 0:
  303. return shape_4d;
  304. case 1:
  305. shape_4d[kC] = shape[kN];
  306. break;
  307. case 2:
  308. shape_4d[kC] = shape[kN];
  309. shape_4d[kH] = shape[kC];
  310. break;
  311. case 3:
  312. shape_4d[kC] = shape[kN];
  313. shape_4d[kH] = shape[kC];
  314. shape_4d[kW] = shape[kH];
  315. break;
  316. case 4:
  317. std::copy(shape.begin(), shape.end(), shape_4d.begin());
  318. break;
  319. default:
  320. MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
  321. }
  322. return shape_4d;
  323. }
  324. } // namespace
  325. bool IsNeedPadding(const std::string &format, const size_t shape_size) {
  326. if (shape_size == 0) {
  327. return false;
  328. }
  329. if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
  330. return false;
  331. } else if (shape_size < kNchwDims) {
  332. return true;
  333. }
  334. return false;
  335. }
  336. ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  337. MS_EXCEPTION_IF_NULL(node);
  338. ShapeVector shape;
  339. std::vector<size_t> host_shape;
  340. if (node->isa<ValueNode>()) {
  341. auto value_node = node->cast<ValueNodePtr>();
  342. MS_EXCEPTION_IF_NULL(value_node);
  343. auto node_value = value_node->value();
  344. MS_EXCEPTION_IF_NULL(node_value);
  345. auto tensor = node_value->cast<tensor::TensorPtr>();
  346. if (tensor == nullptr) {
  347. MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
  348. }
  349. auto shape_temp = tensor->shape();
  350. (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), LongToSize);
  351. if (host_shape.empty()) {
  352. host_shape.push_back(1);
  353. }
  354. } else {
  355. host_shape = AnfAlgo::GetOutputInferShape(node, index);
  356. }
  357. if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) {
  358. host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index));
  359. }
  360. std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong);
  361. return shape;
  362. }
  363. std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis) {
  364. if (padding_axis.empty() || shape.size() != padding_axis.size()) {
  365. return PaddingShapeTo4dByDefault(shape);
  366. }
  367. std::vector<size_t> shape_4d(kNchwDims, 1);
  368. for (size_t index = 0; index < padding_axis.size(); index++) {
  369. shape_4d[padding_axis[index]] = shape[index];
  370. }
  371. return shape_4d;
  372. }
  373. std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
  374. using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
  375. const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
  376. {kOpFormat_NHWC, NhwcDeviceShape},
  377. {kOpFormat_HWCN, HwchDeviceShape},
  378. {kOpFormat_FRAC_Z, FracZDeviceShape},
  379. {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
  380. {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
  381. {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
  382. {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
  383. {kOpFormat_NDHWC, NdhwcDeviceShape}};
  384. if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
  385. return shape;
  386. }
  387. auto temp_shape = shape;
  388. std::vector<size_t> device_shape;
  389. if (format == kOpFormat_FRAC_NZ) {
  390. if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
  391. // For [1] and [1024] shape we can trait it as NZ shape
  392. return shape;
  393. }
  394. if (shape.size() < 2) {
  395. MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
  396. } else {
  397. (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
  398. }
  399. auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
  400. auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1;
  401. device_shape.push_back(w1);
  402. device_shape.push_back(h1);
  403. device_shape.push_back(kCubeSize);
  404. device_shape.push_back(kCubeSize);
  405. return device_shape;
  406. } else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
  407. const size_t c0 = 4;
  408. const size_t h = shape.at(kN) / c0;
  409. const size_t i = shape.at(kC) - h;
  410. const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
  411. const size_t second = c0 * DivCeil(h, kCubeSize);
  412. device_shape.push_back(first);
  413. device_shape.push_back(second);
  414. device_shape.push_back(kCubeSize);
  415. device_shape.push_back(kCubeSize);
  416. return device_shape;
  417. }
  418. if (shape.size() != kNchwDims) {
  419. MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
  420. temp_shape = PaddingShapeTo4dByDefault(shape);
  421. }
  422. auto iter = device_shape_map.find(format);
  423. if (iter == device_shape_map.end()) {
  424. MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  425. }
  426. return iter->second(temp_shape);
  427. }
  428. bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
  429. if (args.host_shape.size() != kNchwDims) {
  430. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  431. return false;
  432. }
  433. MS_EXCEPTION_IF_NULL(size);
  434. MS_EXCEPTION_IF_NULL(total_size);
  435. *size = abstract::TypeIdSize(args.src_data_type);
  436. if (*size < 1) {
  437. MS_LOG(ERROR) << "Illegal dtype.";
  438. return false;
  439. }
  440. *total_size = abstract::ShapeSize(args.device_shape) * (*size);
  441. if (*total_size != args.device_size) {
  442. MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
  443. return false;
  444. }
  445. return true;
  446. }
  447. bool TransDataType(const TypeIdArgs &args, void *result) {
  448. MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
  449. << TypeIdLabel(args.device_data_type);
  450. MS_EXCEPTION_IF_NULL(result);
  451. std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
  452. auto iter = mode_map.find(type_info);
  453. if (iter == mode_map.end()) {
  454. MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
  455. << ", dst_type:" << TypeIdLabel(args.device_data_type);
  456. return false;
  457. }
  458. auto trans_mode = iter->second;
  459. if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
  460. MS_LOG(ERROR) << "Failed to trans datatype..";
  461. return false;
  462. }
  463. return true;
  464. }
  465. bool TransFormat(const FormatArgs &args, void *result) {
  466. using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
  467. const std::map<std::string, FormatTransfer> format_trans_map{
  468. {kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
  469. {kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
  470. {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
  471. MS_LOG(DEBUG) << "Start trans format.";
  472. if (abstract::TypeIdSize(args.src_data_type) < 1) {
  473. MS_LOG(ERROR) << "Invalid datatype..";
  474. return false;
  475. }
  476. if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
  477. return NchwTo4D(args, result);
  478. }
  479. auto iter = format_trans_map.find(args.device_format);
  480. if (iter == format_trans_map.end()) {
  481. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  482. }
  483. return iter->second(args, result);
  484. }
  485. bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
  486. using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
  487. const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw},
  488. {kOpFormat_FRAC_NZ, FracNzToNchw},
  489. {kOpFormat_NC1HWC0, Nc1hwc0ToNchw},
  490. {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
  491. {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
  492. MS_LOG(DEBUG) << "Start trans format.";
  493. if (abstract::TypeIdSize(args.src_data_type) < 1) {
  494. MS_LOG(ERROR) << "Invalid datatype..";
  495. return false;
  496. }
  497. if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
  498. return ToNchw(args, result);
  499. }
  500. auto iter = format_trans_map.find(args.device_format);
  501. if (iter == format_trans_map.end()) {
  502. MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
  503. }
  504. return iter->second(args, result);
  505. }
  506. bool NchwTo4D(const FormatArgs &args, void *result) {
  507. // trans nchw to 4d
  508. MS_LOG(DEBUG) << "Trans format from nchw to 4d.";
  509. MS_EXCEPTION_IF_NULL(result);
  510. size_t size = 0;
  511. size_t total_size = 0;
  512. if (!CheckArgs(args, &size, &total_size)) {
  513. MS_LOG(ERROR) << "Check args failed.";
  514. return false;
  515. }
  516. auto n = args.host_shape[kN];
  517. auto c = args.host_shape[kC];
  518. auto h = args.host_shape[kH];
  519. auto w = args.host_shape[kW];
  520. for (size_t ni = 0; ni < n; ni++) {
  521. for (size_t ci = 0; ci < c; ci++) {
  522. for (size_t hi = 0; hi < h; hi++) {
  523. for (size_t wi = 0; wi < w; wi++) {
  524. auto src_idx = ni * c * h * w + ci * h * w + hi * w + wi;
  525. auto dst_idx = 0;
  526. if (args.device_format == kOpFormat_NHWC) {
  527. dst_idx = ni * h * w * c + hi * w * c + wi * c + ci;
  528. } else if (args.device_format == kOpFormat_HWCN) {
  529. dst_idx = hi * w * c * n + wi * c * n + ci * n + ni;
  530. }
  531. SetData(size, false, src_idx, dst_idx, args, result);
  532. }
  533. }
  534. }
  535. }
  536. return true;
  537. }
  538. bool ToNchw(const FormatArgs &args, void *result) {
  539. MS_LOG(DEBUG) << "Trans format to nchw from 4d.";
  540. MS_EXCEPTION_IF_NULL(result);
  541. size_t size = 0;
  542. size_t total_size = 0;
  543. if (!CheckArgs(args, &size, &total_size)) {
  544. MS_LOG(ERROR) << "Check args failed.";
  545. return false;
  546. }
  547. auto n = args.host_shape[kN];
  548. auto c = args.host_shape[kC];
  549. auto h = args.host_shape[kH];
  550. auto w = args.host_shape[kW];
  551. for (size_t ni = 0; ni < n; ni++) {
  552. for (size_t ci = 0; ci < c; ci++) {
  553. for (size_t hi = 0; hi < h; hi++) {
  554. for (size_t wi = 0; wi < w; wi++) {
  555. auto dst_idx = ni * c * h * w + ci * h * w + hi * w + wi;
  556. auto src_idx = 0;
  557. if (args.device_format == kOpFormat_NHWC) {
  558. src_idx = ni * h * w * c + hi * w * c + wi * c + ci;
  559. } else if (args.device_format == kOpFormat_HWCN) {
  560. src_idx = hi * w * c * n + wi * c * n + ci * n + ni;
  561. }
  562. SetData(size, false, src_idx, dst_idx, args, result);
  563. }
  564. }
  565. }
  566. }
  567. return true;
  568. }
  569. bool NchwToFracZ(const FormatArgs &args, void *result) {
  570. MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  571. MS_EXCEPTION_IF_NULL(result);
  572. if (args.host_shape.size() != kNchwDims) {
  573. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  574. return false;
  575. }
  576. auto size = abstract::TypeIdSize(args.src_data_type);
  577. if (size < 1) {
  578. MS_LOG(ERROR) << "Illegal dtype.";
  579. return false;
  580. }
  581. auto n = args.host_shape[kN];
  582. auto c = args.host_shape[kC];
  583. auto h = args.host_shape[kH];
  584. auto w = args.host_shape[kW];
  585. auto c0 = CubeSizeByType(args.src_data_type);
  586. if (c0 < 1) {
  587. MS_LOG(ERROR) << "Illegal dtype.";
  588. return false;
  589. }
  590. auto c1 = DivCeil(c, c0);
  591. auto hw = h * w;
  592. auto chw = c * hw;
  593. auto hwc0 = hw * c0;
  594. auto nchw = n * chw;
  595. auto hf_cnt = DivCeil(n, kCubeSize);
  596. auto vf_cnt = c1 * hw;
  597. auto fractal_ele_cnt = c0 * kCubeSize;
  598. auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
  599. auto dst_size = total_ele_cnt * size;
  600. if (dst_size != args.device_size) {
  601. MS_LOG(ERROR) << "Illegal total data size."
  602. << "dst size is :" << dst_size << "device size is :" << args.device_size;
  603. return false;
  604. }
  605. for (size_t vfi = 0; vfi < vf_cnt; vfi++) {
  606. auto vf_base_i = vfi * hf_cnt; // vertical fractal matrix base index
  607. for (size_t hfi = 0; hfi < hf_cnt; hfi++) {
  608. auto gfi = vf_base_i + hfi; // global fractal matrix index
  609. auto src_n_offset = hfi * chw * kCubeSize;
  610. auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
  611. for (size_t row = 0; row < c0; row++) {
  612. auto src_ci = vfi / hw * c0 + row;
  613. auto src_row_offset = src_f_offset + row * hw;
  614. for (size_t col = 0; col < kCubeSize; col++) {
  615. auto src_ni = hfi * kCubeSize + col;
  616. auto src_idx = src_row_offset + chw * col;
  617. auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
  618. auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
  619. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  620. }
  621. }
  622. }
  623. }
  624. return true;
  625. }
  626. bool FracZToNchw(const FormatArgs &args, void *result) {
  627. MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
  628. MS_EXCEPTION_IF_NULL(result);
  629. if (args.host_shape.size() != kNchwDims) {
  630. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  631. return false;
  632. }
  633. auto size = abstract::TypeIdSize(args.src_data_type);
  634. if (size < 1) {
  635. MS_LOG(ERROR) << "Illegal dtype.";
  636. return false;
  637. }
  638. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  639. if (total_size != args.device_size) {
  640. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  641. return false;
  642. }
  643. auto n0 = args.device_shape.at(1);
  644. auto ni = args.device_shape.at(2);
  645. auto c0 = args.device_shape.at(3);
  646. auto n = args.host_shape[kN];
  647. auto c = args.host_shape[kC];
  648. auto h = args.host_shape[kH];
  649. auto w = args.host_shape[kW];
  650. auto nc = ni * n0;
  651. auto ncc0 = nc * c0;
  652. auto wncc0 = w * ncc0;
  653. auto hwncc0 = h * wncc0;
  654. auto hw = h * w;
  655. auto chw = c * hw;
  656. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  657. size_t n_head_addr = n_idx * chw;
  658. for (size_t c_idx = 0; c_idx < c; c_idx++) {
  659. size_t c_head_addr = n_head_addr + c_idx * hw;
  660. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  661. size_t h_head_addr = c_head_addr + h_idx * w;
  662. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  663. size_t dst_idx = h_head_addr + w_idx;
  664. size_t c1_idx = c_idx / c0;
  665. size_t c0_idx = c_idx % c0;
  666. size_t nc_idx = n_idx;
  667. size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
  668. SetData(size, false, src_idx, dst_idx, args, result);
  669. }
  670. }
  671. }
  672. }
  673. return true;
  674. }
  675. bool NchwToFracZc04(const FormatArgs &args, void *result) {
  676. // trans nchw to FracZc04
  677. MS_LOG(DEBUG) << "Trans format from nchw to FracZc04.";
  678. MS_EXCEPTION_IF_NULL(result);
  679. size_t size = 0;
  680. size_t total_size = 0;
  681. if (!CheckArgs(args, &size, &total_size)) {
  682. MS_LOG(ERROR) << "Check args failed.";
  683. return false;
  684. }
  685. auto cube = kCubeSize;
  686. auto n = args.host_shape[kN];
  687. auto c = args.host_shape[kC];
  688. auto h = args.host_shape[kH];
  689. auto w = args.host_shape[kW];
  690. const size_t c0 = 4;
  691. auto c1 = DivCeil(c, c0);
  692. auto hwc0 = h * w * c0;
  693. auto hwc = h * w * c;
  694. auto nhwc = n * h * w * c;
  695. auto n_cnt = DivCeil(n, cube);
  696. auto v_cnt = DivCeil(h * w * c0 * c1, cube);
  697. size_t dst_idx = 0;
  698. for (size_t vi = 0; vi < v_cnt; vi++) {
  699. for (size_t ni = 0; ni < n_cnt; ni++) {
  700. for (size_t col = 0; col < cube; col++) {
  701. for (size_t row = 0; row < cube; row++) {
  702. size_t cur_cube_n = cube * ni + col;
  703. size_t cur_cube_c1hwc0 = cube * vi + row;
  704. auto desc_g = cur_cube_n / n;
  705. auto desc_n = cur_cube_n % n;
  706. auto desc_c1 = cur_cube_c1hwc0 / hwc0;
  707. auto desc_c0 = cur_cube_c1hwc0 % c0;
  708. auto desc_h = (cur_cube_c1hwc0 - hwc0 * desc_c1) / (w * c0);
  709. auto desc_w = (cur_cube_c1hwc0 - hwc0 * desc_c1 - w * c0 * desc_h) / c0;
  710. auto c_idx = desc_c1 * c0 + desc_c0;
  711. auto src_idx = desc_g * nhwc + desc_n * hwc + c_idx * h * w + desc_h * w + desc_w;
  712. auto pad_zero = desc_g >= 1 || desc_n >= n || c_idx >= c;
  713. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  714. dst_idx++;
  715. }
  716. }
  717. }
  718. }
  719. return true;
  720. }
  721. bool NchwToNc1hwc04(const FormatArgs &args, void *result) {
  722. MS_LOG(DEBUG) << "Trans format from nchw to Nc1hwc04.";
  723. return NchwToNc1hwc0(args, result);
  724. }
  725. bool Nc1hwc04ToNchw(const FormatArgs &args, void *result) {
  726. MS_LOG(DEBUG) << "Trans format from Nc1hwc04 to nchw.";
  727. return Nc1hwc0ToNchw(args, result);
  728. }
  729. bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
  730. MS_EXCEPTION_IF_NULL(hw_shape);
  731. if (host_shape.empty()) {
  732. MS_LOG(ERROR) << "Size of vector is 0.";
  733. return false;
  734. }
  735. switch (host_shape.size()) {
  736. case 1:
  737. hw_shape->push_back(1);
  738. hw_shape->push_back(1);
  739. hw_shape->push_back(host_shape[0]);
  740. return true;
  741. default:
  742. auto size = host_shape.size();
  743. if (size < 2) {
  744. MS_LOG(ERROR) << "Illegal size.";
  745. return false;
  746. }
  747. size_t times = 1;
  748. for (size_t i = 0; i != size - 2; i++) {
  749. times *= host_shape[i];
  750. }
  751. hw_shape->push_back(times);
  752. hw_shape->push_back(host_shape[size - 2]);
  753. hw_shape->push_back(host_shape[size - 1]);
  754. return true;
  755. }
  756. }
  757. bool NchwToFracNz(const FormatArgs &args, void *result) {
  758. MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  759. MS_EXCEPTION_IF_NULL(result);
  760. std::vector<size_t> hw_shape;
  761. if (!TransShapeToNz(args.host_shape, &hw_shape)) {
  762. MS_LOG(ERROR) << "Trans shape failed..";
  763. return false;
  764. }
  765. if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
  766. MS_LOG(ERROR) << "Invalid shape size.";
  767. return false;
  768. }
  769. auto size = abstract::TypeIdSize(args.src_data_type);
  770. if (size < 1) {
  771. MS_LOG(ERROR) << "Illegal dtype";
  772. return false;
  773. }
  774. auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  775. if (dst_size != args.device_size) {
  776. MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
  777. return false;
  778. }
  779. auto times = hw_shape.at(0);
  780. auto h = hw_shape.at(1);
  781. auto w = hw_shape.at(2);
  782. auto hw = h * w;
  783. auto shape_size = args.device_shape.size();
  784. auto w1 = args.device_shape[shape_size - 4];
  785. auto h1 = args.device_shape[shape_size - 3];
  786. auto h0 = args.device_shape[shape_size - 2];
  787. auto w0 = args.device_shape[shape_size - 1];
  788. auto h1h0w0 = h1 * h0 * w0;
  789. auto w1h1h0w0 = w1 * h1h0w0;
  790. auto num_w1 = w / w0;
  791. for (size_t times_idx = 0; times_idx < times; times_idx++) {
  792. auto times_head = times_idx * w1h1h0w0;
  793. auto src_times_head = times_idx * hw;
  794. for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
  795. auto h1h0_head = times_head + h1h0_idx * w0;
  796. auto src_h_head = src_times_head + h1h0_idx * w;
  797. for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
  798. for (size_t i = 0; i < w0; ++i) {
  799. size_t src_idx = src_h_head + w1_idx * w0 + i;
  800. size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
  801. SetData(size, false, src_idx, dst_idx, args, result);
  802. }
  803. }
  804. auto w1_head = num_w1 * w0;
  805. for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
  806. auto src_w_idx = w1_head + w0_idx;
  807. size_t dst_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
  808. size_t src_idx = src_h_head + src_w_idx;
  809. SetData(size, false, src_idx, dst_idx, args, result);
  810. }
  811. }
  812. }
  813. return true;
  814. }
  815. bool FracNzToNchw(const FormatArgs &args, void *result) {
  816. MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
  817. MS_EXCEPTION_IF_NULL(result);
  818. std::vector<size_t> hw_shape;
  819. if (!TransShapeToNz(args.host_shape, &hw_shape)) {
  820. MS_LOG(ERROR) << "Trans shape failed..";
  821. return false;
  822. }
  823. if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
  824. MS_LOG(ERROR) << "Invalid shape size.";
  825. return false;
  826. }
  827. auto size = abstract::TypeIdSize(args.src_data_type);
  828. if (size < 1) {
  829. MS_LOG(ERROR) << "Illegal dtype";
  830. return false;
  831. }
  832. auto dst_size = abstract::ShapeSize(args.device_shape) * size;
  833. if (dst_size != args.device_size) {
  834. MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
  835. return false;
  836. }
  837. auto times = hw_shape.at(0);
  838. auto h = hw_shape.at(1);
  839. auto w = hw_shape.at(2);
  840. auto hw = h * w;
  841. auto shape_size = args.device_shape.size();
  842. auto w1 = args.device_shape[shape_size - 4];
  843. auto h1 = args.device_shape[shape_size - 3];
  844. auto h0 = args.device_shape[shape_size - 2];
  845. auto w0 = args.device_shape[shape_size - 1];
  846. auto h1h0w0 = h1 * h0 * w0;
  847. auto w1h1h0w0 = w1 * h1h0w0;
  848. auto num_w1 = w / w0;
  849. for (size_t times_idx = 0; times_idx < times; times_idx++) {
  850. auto times_head = times_idx * w1h1h0w0;
  851. auto src_times_head = times_idx * hw;
  852. for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
  853. auto h1h0_head = times_head + h1h0_idx * w0;
  854. auto src_h_head = src_times_head + h1h0_idx * w;
  855. for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
  856. for (size_t i = 0; i < w0; ++i) {
  857. size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
  858. size_t dst_idx = src_h_head + w1_idx * w0 + i;
  859. SetData(size, false, src_idx, dst_idx, args, result);
  860. }
  861. }
  862. auto w1_head = num_w1 * w0;
  863. for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
  864. auto src_w_idx = w1_head + w0_idx;
  865. size_t src_idx = h1h0_head + num_w1 * h1h0w0 + w0_idx;
  866. size_t dst_idx = src_h_head + src_w_idx;
  867. SetData(size, false, src_idx, dst_idx, args, result);
  868. }
  869. }
  870. }
  871. return true;
  872. }
  873. bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
  874. MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
  875. MS_EXCEPTION_IF_NULL(result);
  876. if (args.host_shape.size() != kNchwDims) {
  877. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  878. return false;
  879. }
  880. auto size = abstract::TypeIdSize(args.src_data_type);
  881. if (size < 1) {
  882. MS_LOG(ERROR) << "Illegal dtype.";
  883. return false;
  884. }
  885. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  886. if (total_size != args.device_size) {
  887. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  888. return false;
  889. }
  890. auto n = args.host_shape[kN];
  891. auto c = args.host_shape[kC];
  892. auto h = args.host_shape[kH];
  893. auto w = args.host_shape[kW];
  894. auto c0 = CubeSizeByType(args.src_data_type);
  895. if (c0 < 1) {
  896. MS_LOG(ERROR) << "Illegal dtype.";
  897. return false;
  898. }
  899. if (args.device_format == kOpFormat_NC1HWC0_C04) {
  900. c0 = 4;
  901. }
  902. auto c1 = DivCeil(c, c0);
  903. auto hw = h * w;
  904. auto chw = c * hw;
  905. auto c1hwc0 = c1 * hw * c0;
  906. auto wc0 = w * c0;
  907. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  908. size_t n_head_addr = n_idx * c1hwc0;
  909. for (size_t c1_idx = 0; c1_idx < c1; c1_idx++) {
  910. size_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
  911. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  912. size_t h_head_addr = c1_head_addr + h_idx * wc0;
  913. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  914. size_t w_head_addr = h_head_addr + w_idx * c0;
  915. for (size_t c0_idx = 0; c0_idx < c0; c0_idx++) {
  916. size_t dst_idx = c0_idx + w_head_addr;
  917. size_t c_idx = c0_idx + c1_idx * c0;
  918. size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
  919. auto pad_zero = c_idx >= c;
  920. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  921. }
  922. }
  923. }
  924. }
  925. }
  926. return true;
  927. }
  928. bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
  929. MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
  930. MS_EXCEPTION_IF_NULL(result);
  931. if (args.host_shape.size() != kNchwDims) {
  932. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  933. return false;
  934. }
  935. auto size = abstract::TypeIdSize(args.src_data_type);
  936. if (size < 1) {
  937. MS_LOG(ERROR) << "Illegal dtype.";
  938. return false;
  939. }
  940. auto total_size = abstract::ShapeSize(args.device_shape) * size;
  941. if (total_size != args.device_size) {
  942. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  943. return false;
  944. }
  945. auto n = args.host_shape[kN];
  946. auto c = args.host_shape[kC];
  947. auto h = args.host_shape[kH];
  948. auto w = args.host_shape[kW];
  949. auto c1 = args.device_shape[1];
  950. auto c0 = args.device_shape[4];
  951. auto hw = h * w;
  952. auto chw = c * hw;
  953. auto wc0 = w * c0;
  954. auto hwc0 = h * wc0;
  955. auto c1hwc0 = c1 * hwc0;
  956. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  957. size_t n_head_addr = n_idx * chw;
  958. for (size_t c_idx = 0; c_idx < c; c_idx++) {
  959. size_t c_head_addr = n_head_addr + c_idx * hw;
  960. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  961. size_t h_head_addr = c_head_addr + h_idx * w;
  962. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  963. size_t dst_idx = h_head_addr + w_idx;
  964. size_t c1_idx = c_idx / c0;
  965. size_t c0_idx = c_idx % c0;
  966. size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
  967. SetData(size, false, src_idx, dst_idx, args, result);
  968. }
  969. }
  970. }
  971. }
  972. return true;
  973. }
  974. bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
  975. // trans nchw to c1hwncoc0
  976. MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
  977. MS_EXCEPTION_IF_NULL(result);
  978. size_t size = 0;
  979. size_t total_size = 0;
  980. if (!CheckArgs(args, &size, &total_size)) {
  981. MS_LOG(ERROR) << "Check args failed.";
  982. return false;
  983. }
  984. auto n = args.host_shape[kN];
  985. auto c = args.host_shape[kC];
  986. auto h = args.host_shape[kH];
  987. auto w = args.host_shape[kW];
  988. const int co_idx = 4;
  989. const int c0_idx = 5;
  990. auto c1 = args.device_shape[0];
  991. auto co = args.device_shape[co_idx];
  992. auto c0 = args.device_shape[c0_idx];
  993. for (size_t c1_i = 0; c1_i < c1; c1_i++) {
  994. for (size_t h_i = 0; h_i < h; h_i++) {
  995. for (size_t w_i = 0; w_i < w; w_i++) {
  996. for (size_t n_i = 0; n_i < n; n_i++) {
  997. for (size_t co_i = 0; co_i < co; co_i++) {
  998. for (size_t c0_i = 0; c0_i < c0; c0_i++) {
  999. size_t dst_idx = c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
  1000. co_i * c0 + c0_i;
  1001. size_t c_i = c0_i + c1_i * c0;
  1002. size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
  1003. auto pad_zero = !(c_i < c && c0_i == co_i);
  1004. SetData(size, pad_zero, src_idx, dst_idx, args, result);
  1005. }
  1006. }
  1007. }
  1008. }
  1009. }
  1010. }
  1011. return true;
  1012. }
  1013. bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
  1014. // trans c1hwncoc0 to nchw
  1015. MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
  1016. MS_EXCEPTION_IF_NULL(result);
  1017. size_t size = 0;
  1018. size_t total_size = 0;
  1019. if (!CheckArgs(args, &size, &total_size)) {
  1020. MS_LOG(ERROR) << "Check args failed.";
  1021. return false;
  1022. }
  1023. auto n = args.host_shape[kN];
  1024. auto c = args.host_shape[kC];
  1025. auto h = args.host_shape[kH];
  1026. auto w = args.host_shape[kW];
  1027. const int co_idx = 4;
  1028. const int c0_idx = 5;
  1029. auto co = args.device_shape[co_idx];
  1030. auto c0 = args.device_shape[c0_idx];
  1031. for (size_t n_i = 0; n_i < n; n_i++) {
  1032. for (size_t c_i = 0; c_i < c; c_i++) {
  1033. for (size_t h_i = 0; h_i < h; h_i++) {
  1034. for (size_t w_i = 0; w_i < w; w_i++) {
  1035. size_t dst_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
  1036. size_t c1_i = c_i / kCubeSize;
  1037. size_t c0_i = c_i % kCubeSize;
  1038. size_t co_i = c0_i;
  1039. size_t src_idx =
  1040. c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + co_i * c0 + c0_i;
  1041. SetData(size, false, src_idx, dst_idx, args, result);
  1042. }
  1043. }
  1044. }
  1045. }
  1046. return true;
  1047. }
  1048. } // namespace trans
  1049. } // namespace mindspore