You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trans.cc 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/trans.h"
  17. #include <algorithm>
  18. #include <functional>
  19. #include <numeric>
  20. #include <utility>
  21. #include "./securec.h"
  22. #include "common/utils.h"
  23. #include "session/anf_runtime_algorithm.h"
  24. #include "kernel/kernel.h"
  25. #include "device/convert_tensor_utils.h"
  26. #include "utils/convert_utils.h"
  27. #include "utils/log_adapter.h"
  28. #include "utils/utils.h"
  29. namespace mindspore {
  30. namespace trans {
  31. namespace {
  32. std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
  33. std::vector<size_t> shape_4d(4, 1);
  34. switch (shape.size()) {
  35. case 0:
  36. return shape_4d;
  37. case 1:
  38. shape_4d[1] = shape[0];
  39. break;
  40. case 2:
  41. shape_4d[1] = shape[0];
  42. shape_4d[2] = shape[1];
  43. break;
  44. case 3:
  45. shape_4d[1] = shape[0];
  46. shape_4d[2] = shape[1];
  47. shape_4d[3] = shape[2];
  48. break;
  49. case 4:
  50. std::copy(shape.begin(), shape.end(), shape_4d.begin());
  51. break;
  52. default:
  53. MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
  54. }
  55. return shape_4d;
  56. }
  57. } // namespace
  58. const size_t kNchwDims = 4;
  59. const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1},
  60. {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
  61. {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2},
  62. {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4},
  63. {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
  // Integer ceiling division: the number of n2-sized chunks needed to cover n1.
  // Returns 0 when n2 is 0 (no valid chunk size) and when n1 is 0 (nothing to cover).
  // The explicit n1 == 0 guard is required because (n1 - 1) wraps around for unsigned T
  // and would otherwise produce a huge (unsigned) or off-by-one (signed) result.
  template <typename T>
  T Ceil(T n1, T n2) {
    if (n2 == 0 || n1 == 0) {
      return 0;
    }
    return (n1 - 1) / n2 + 1;
  }
  // Element-type conversions supported by CastKernel, named FROM_<source>_TO_<destination>.
  // The numeric values are spelled out; they match the implicit numbering of the declaration order.
  enum DataTypeTransMode {
    FROM_FLOAT_TO_FLOAT16 = 0,
    FROM_FLOAT_TO_INT32 = 1,
    FROM_FLOAT16_TO_FLOAT = 2,
    FROM_FLOAT16_TO_INT32 = 3,
    FROM_INT32_TO_FLOAT = 4,
    FROM_INT32_TO_FLOAT16 = 5,
    FROM_INT32_TO_UINT8 = 6,
    FROM_INT32_TO_INT8 = 7,
    FROM_UINT8_TO_FLOAT = 8,
    FROM_UINT8_TO_INT32 = 9,
    FROM_INT8_TO_FLOAT = 10,
    FROM_INT8_TO_INT32 = 11,
    FROM_INT64_TO_INT32 = 12,
    FROM_UINT16_TO_INT32 = 13,
  };
  84. const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
  85. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeFloat16), FROM_FLOAT_TO_FLOAT16},
  86. {std::pair<TypeId, TypeId>(kNumberTypeFloat32, kNumberTypeInt32), FROM_FLOAT_TO_INT32},
  87. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeFloat32), FROM_FLOAT16_TO_FLOAT},
  88. {std::pair<TypeId, TypeId>(kNumberTypeFloat16, kNumberTypeInt32), FROM_FLOAT16_TO_INT32},
  89. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat32), FROM_INT32_TO_FLOAT},
  90. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeFloat16), FROM_INT32_TO_FLOAT16},
  91. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeUInt8), FROM_INT32_TO_UINT8},
  92. {std::pair<TypeId, TypeId>(kNumberTypeInt32, kNumberTypeInt8), FROM_INT32_TO_INT8},
  93. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeFloat32), FROM_UINT8_TO_FLOAT},
  94. {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32},
  95. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT},
  96. {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32},
  97. {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32},
  98. {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}};
  99. void CheckMemSize(const TypeIdArgs &args) {
  100. auto src_type_size = TypeIdSize(args.host_data_type);
  101. auto dst_type_size = TypeIdSize(args.device_data_type);
  102. if (src_type_size < 1 || dst_type_size < 1) {
  103. MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
  104. }
  105. if (args.data_size / src_type_size != args.host_shape_size) {
  106. MS_LOG(EXCEPTION) << "Invalid src or dst data size.";
  107. }
  108. }
  109. template <typename SrcT, typename DstT>
  110. void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
  111. CheckMemSize(args);
  112. for (size_t idx = 0; idx != data_size; idx++) {
  113. SrcT src_data = static_cast<const SrcT *>(args.data)[idx];
  114. static_cast<DstT *>(dst)[idx] = static_cast<DstT>(src_data);
  115. }
  116. }
  117. template <typename SrcT>
  118. void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) {
  119. CheckMemSize(args);
  120. auto src_data = static_cast<const SrcT *>(args.data);
  121. auto half_data = static_cast<Eigen::half *>(dst);
  122. for (size_t i = 0; i < data_size; i++) {
  123. half_data[i] = Eigen::half(src_data[i]);
  124. }
  125. }
  126. bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) {
  127. switch (mode) {
  128. case FROM_FLOAT_TO_FLOAT16:
  129. device::FloatToHalf(dst, args.data, data_size);
  130. break;
  131. case FROM_INT32_TO_FLOAT16:
  132. TransDataSrc2Fp16<int32_t>(args, dst, data_size);
  133. break;
  134. case FROM_FLOAT16_TO_FLOAT:
  135. device::HalfToFloat(dst, args.data, data_size);
  136. break;
  137. case FROM_FLOAT_TO_INT32:
  138. TransDataSrc2Dst<float, int32_t>(args, dst, data_size);
  139. break;
  140. case FROM_FLOAT16_TO_INT32:
  141. TransDataSrc2Dst<float16, int32_t>(args, dst, data_size);
  142. break;
  143. case FROM_INT32_TO_FLOAT:
  144. TransDataSrc2Dst<int32_t, float>(args, dst, data_size);
  145. break;
  146. case FROM_INT32_TO_INT8:
  147. TransDataSrc2Dst<int32_t, int8_t>(args, dst, data_size);
  148. break;
  149. case FROM_INT32_TO_UINT8:
  150. TransDataSrc2Dst<int32_t, uint8_t>(args, dst, data_size);
  151. break;
  152. case FROM_UINT8_TO_INT32:
  153. TransDataSrc2Dst<uint8_t, int32_t>(args, dst, data_size);
  154. break;
  155. case FROM_UINT8_TO_FLOAT:
  156. TransDataSrc2Dst<uint8_t, float>(args, dst, data_size);
  157. break;
  158. case FROM_INT8_TO_FLOAT:
  159. TransDataSrc2Dst<int8_t, float>(args, dst, data_size);
  160. break;
  161. case FROM_INT8_TO_INT32:
  162. TransDataSrc2Dst<int8_t, int32_t>(args, dst, data_size);
  163. break;
  164. case FROM_INT64_TO_INT32:
  165. TransDataSrc2Dst<int64_t, int32_t>(args, dst, data_size);
  166. break;
  167. case FROM_UINT16_TO_INT32:
  168. TransDataSrc2Dst<uint16_t, int32_t>(args, dst, data_size);
  169. break;
  170. default:
  171. MS_LOG(ERROR) << "Unsupported datatype trans";
  172. return false;
  173. }
  174. return true;
  175. }
  176. size_t CubeSizeByType(const TypeId data_type) {
  177. const size_t default_error = 0;
  178. auto dt_size = TypeIdSize(data_type);
  179. if (dt_size < 1) {
  180. MS_LOG(ERROR) << "Illegal dtype.";
  181. return default_error;
  182. } else if (dt_size == 1) {
  183. return kCubeSize * 2;
  184. }
  185. return kCubeSize;
  186. }
  // Returns the product of all dimensions in shape (1 for an empty shape).
  // The initial value is an explicit size_t: std::accumulate carries its running result in the
  // type of the initial value, so a plain `1` (int) would truncate large products.
  size_t ShapeSize(const std::vector<size_t> &shape) {
    return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
  }
  191. size_t TypeIdSize(const TypeId data_type) {
  192. const size_t unsupported_type_error = 0;
  193. auto iter = type_map.find(data_type);
  194. if (iter != type_map.end()) {
  195. return iter->second;
  196. }
  197. return unsupported_type_error;
  198. }
  199. bool IsNeedPadding(const std::string &format, const size_t shape_size) {
  200. if (shape_size == 0) {
  201. return false;
  202. }
  203. if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
  204. return false;
  205. } else if (shape_size < 4) {
  206. return true;
  207. }
  208. return false;
  209. }
  210. std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
  211. std::vector<int> shape;
  212. std::vector<size_t> host_shape;
  213. if (node->isa<ValueNode>()) {
  214. auto value_node = node->cast<ValueNodePtr>();
  215. auto node_value = value_node->value();
  216. auto tensor = node_value->cast<tensor::TensorPtr>();
  217. if (tensor == nullptr) {
  218. MS_LOG(EXCEPTION) << " the node[ " << node->DebugString() << "]'s cannot convert ";
  219. }
  220. auto shape_temp = tensor->shape();
  221. (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize);
  222. if (host_shape.empty()) {
  223. host_shape.push_back(1);
  224. }
  225. } else {
  226. host_shape = AnfAlgo::GetOutputInferShape(node, index);
  227. }
  228. if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
  229. host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
  230. }
  231. std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
  232. return shape;
  233. }
  234. std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<kernel::Axis> &padding_axis) {
  235. if (padding_axis.empty() || shape.size() != padding_axis.size()) {
  236. return PaddingShapeTo4dByDefault(shape);
  237. }
  238. std::vector<size_t> shape_4d(4, 1);
  239. for (size_t index = 0; index < padding_axis.size(); index++) {
  240. shape_4d[padding_axis[index]] = shape[index];
  241. }
  242. return shape_4d;
  243. }
  244. namespace {
  245. bool CheckDims(const std::vector<size_t> &shape) {
  246. if (shape.size() != 4) {
  247. MS_LOG(ERROR) << "Host shape dims shoud be 4";
  248. return false;
  249. }
  250. return true;
  251. }
  252. std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
  253. if (!CheckDims(shape)) {
  254. MS_LOG(EXCEPTION) << "Check dims failed.";
  255. }
  256. return shape;
  257. }
  258. std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
  259. if (!CheckDims(shape)) {
  260. MS_LOG(EXCEPTION) << "Ccheck dims failed.";
  261. }
  262. std::vector<size_t> device_shape;
  263. device_shape.push_back(shape[0]);
  264. device_shape.push_back(shape[2]);
  265. device_shape.push_back(shape[3]);
  266. device_shape.push_back(shape[1]);
  267. return device_shape;
  268. }
  269. std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
  270. if (!CheckDims(shape)) {
  271. MS_LOG(EXCEPTION) << "Check dims failed.";
  272. }
  273. std::vector<size_t> device_shape;
  274. device_shape.push_back(shape[2]);
  275. device_shape.push_back(shape[3]);
  276. device_shape.push_back(shape[1]);
  277. device_shape.push_back(shape[0]);
  278. return device_shape;
  279. }
  280. std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
  281. if (!CheckDims(shape)) {
  282. MS_LOG(EXCEPTION) << "Check dims failed.";
  283. }
  284. std::vector<size_t> device_shape;
  285. size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  286. size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
  287. device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
  288. device_shape.push_back(cout16 / kCubeSize);
  289. device_shape.push_back(kCubeSize);
  290. device_shape.push_back(kCubeSize);
  291. return device_shape;
  292. }
  293. std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
  294. if (!CheckDims(shape)) {
  295. MS_LOG(EXCEPTION) << "Check dims failed.";
  296. }
  297. std::vector<size_t> device_shape;
  298. size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
  299. size_t C0 = kCubeSize;
  300. device_shape.push_back(shape[0]);
  301. device_shape.push_back(C1);
  302. device_shape.push_back(shape[2]);
  303. device_shape.push_back(shape[3]);
  304. device_shape.push_back(C0);
  305. return device_shape;
  306. }
  307. std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
  308. if (!CheckDims(shape)) {
  309. MS_LOG(EXCEPTION) << "Check dims failed.";
  310. }
  311. std::vector<size_t> device_shape;
  312. device_shape.push_back((shape[1] - 1) / kCubeSize + 1);
  313. device_shape.push_back(shape[2]);
  314. device_shape.push_back(shape[3]);
  315. device_shape.push_back(shape[0]);
  316. device_shape.push_back(kCubeSize);
  317. device_shape.push_back(kCubeSize);
  318. return device_shape;
  319. }
  320. } // namespace
  321. std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
  322. using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
  323. const std::map<std::string, DeviceShapeTransfer> device_shape_map{
  324. {kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape},
  325. {kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape},
  326. {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
  327. };
  328. if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
  329. return shape;
  330. }
  331. auto temp_shape = shape;
  332. std::vector<size_t> device_shape;
  333. if (format == kOpFormat_FRAC_NZ) {
  334. if (shape.size() < 2) {
  335. MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
  336. } else {
  337. (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
  338. }
  339. auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
  340. auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1;
  341. device_shape.push_back(w1);
  342. device_shape.push_back(h1);
  343. device_shape.push_back(kCubeSize);
  344. device_shape.push_back(kCubeSize);
  345. return device_shape;
  346. }
  347. if (shape.size() != 4) {
  348. MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
  349. temp_shape = PaddingShapeTo4dByDefault(shape);
  350. }
  351. auto iter = device_shape_map.find(format);
  352. if (iter != device_shape_map.end()) {
  353. return iter->second(temp_shape);
  354. }
  355. MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
  356. }
  357. bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
  358. if (args.host_shape.size() != kNchwDims) {
  359. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  360. return false;
  361. }
  362. *size = TypeIdSize(args.src_data_type);
  363. if (*size < 1) {
  364. MS_LOG(ERROR) << "Illegal dtype.";
  365. return false;
  366. }
  367. *total_size = ShapeSize(args.device_shape) * (*size);
  368. if (*total_size != args.device_size) {
  369. MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
  370. return false;
  371. }
  372. return true;
  373. }
  374. bool TransDataType(const TypeIdArgs &args, void *result) {
  375. MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
  376. << TypeIdLabel(args.device_data_type);
  377. MS_EXCEPTION_IF_NULL(result);
  378. std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
  379. auto iter = mode_map.find(type_info);
  380. if (iter == mode_map.end()) {
  381. MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
  382. << ", dst_type:" << TypeIdLabel(args.device_data_type);
  383. return false;
  384. }
  385. auto trans_mode = iter->second;
  386. if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
  387. MS_LOG(ERROR) << "Failed to trans datatype..";
  388. return false;
  389. }
  390. return true;
  391. }
  392. bool TransFormat(const FormatArgs &args, void *result) {
  393. MS_LOG(DEBUG) << "Start trans format.";
  394. if (TypeIdSize(args.src_data_type) < 1) {
  395. MS_LOG(ERROR) << "Invalid datatype..";
  396. return false;
  397. }
  398. if (args.device_format == kOpFormat_FRAC_Z) {
  399. return NchwToFracZ(args, result);
  400. } else if (args.device_format == kOpFormat_FRAC_NZ) {
  401. return NchwToFracNz(args, result);
  402. } else if (args.device_format == kOpFormat_NC1HWC0) {
  403. return NchwToNc1hwc0(args, result);
  404. } else if (args.device_format == kOpFormat_C1HWNCoC0) {
  405. return NchwToC1hwncoc0(args, result);
  406. }
  407. return true;
  408. }
  409. bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
  410. MS_LOG(DEBUG) << "Start trans format.";
  411. if (TypeIdSize(args.src_data_type) < 1) {
  412. MS_LOG(ERROR) << "Invalid datatype..";
  413. return false;
  414. }
  415. if (args.device_format == kOpFormat_FRAC_Z) {
  416. return FracZToNchw(args, result);
  417. } else if (args.device_format == kOpFormat_FRAC_NZ) {
  418. return FracNzToNchw(args, result);
  419. } else if (args.device_format == kOpFormat_NC1HWC0) {
  420. return Nc1hwc0ToNchw(args, result);
  421. } else if (args.device_format == kOpFormat_C1HWNCoC0) {
  422. return C1hwncoc0ToNchw(args, result);
  423. }
  424. return true;
  425. }
  426. bool NchwToFracZ(const FormatArgs &args, void *result) {
  427. MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  428. MS_EXCEPTION_IF_NULL(result);
  429. if (args.host_shape.size() != kNchwDims) {
  430. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  431. return false;
  432. }
  433. size_t size = TypeIdSize(args.src_data_type);
  434. if (size < 1) {
  435. MS_LOG(ERROR) << "Illegal dtype.";
  436. return false;
  437. }
  438. auto n = args.host_shape[0];
  439. auto c = args.host_shape[1];
  440. auto h = args.host_shape[2];
  441. auto w = args.host_shape[3];
  442. size_t c0 = CubeSizeByType(args.src_data_type);
  443. if (c0 < 1) {
  444. MS_LOG(ERROR) << "Illegal dtype.";
  445. return false;
  446. }
  447. size_t c1 = Ceil(c, c0);
  448. size_t hw = h * w;
  449. size_t chw = c * hw;
  450. size_t hwc0 = hw * c0;
  451. size_t nchw = n * chw;
  452. size_t hf_cnt = Ceil(n, kCubeSize);
  453. size_t vf_cnt = c1 * hw;
  454. size_t fractal_ele_cnt = c0 * kCubeSize;
  455. size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
  456. size_t dst_size = total_ele_cnt * size;
  457. if (dst_size != args.device_size) {
  458. MS_LOG(ERROR) << "Illegal total data size."
  459. << "dst size is :" << dst_size << "device size is :" << args.device_size;
  460. return false;
  461. }
  462. for (size_t vfi = 0; vfi < vf_cnt; vfi++) {
  463. auto vf_base_i = vfi * hf_cnt; // vertical fractal matrix base index
  464. for (size_t hfi = 0; hfi < hf_cnt; hfi++) {
  465. auto gfi = vf_base_i + hfi; // global fractal matrix index
  466. auto src_n_offset = hfi * chw * kCubeSize;
  467. auto src_f_offset = src_n_offset + vfi % hw + vfi / hw * hwc0;
  468. for (size_t row = 0; row < c0; row++) {
  469. auto src_ci = vfi / hw * c0 + row;
  470. auto src_row_offset = src_f_offset + row * hw;
  471. for (size_t col = 0; col < kCubeSize; col++) {
  472. auto src_ni = hfi * kCubeSize + col;
  473. auto src_offset = src_row_offset + chw * col;
  474. auto need_pad_zero = src_ni >= n || src_offset >= nchw || src_ci >= c;
  475. auto idx = gfi * fractal_ele_cnt + col * c0 + row;
  476. auto offset = idx * size;
  477. auto protected_size = dst_size - offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  478. ? dst_size - offset
  479. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  480. errno_t ret;
  481. if (need_pad_zero) {
  482. ret = memset_s(static_cast<uint8_t *>(result) + offset, protected_size, 0, size);
  483. } else {
  484. ret = memcpy_s(static_cast<uint8_t *>(result) + offset, protected_size,
  485. static_cast<uint8_t const *>(args.data) + src_offset * size, size);
  486. }
  487. if (ret != 0) {
  488. MS_LOG(ERROR) << "Failed to operate the dst memory error-code " << ret;
  489. return false;
  490. }
  491. }
  492. }
  493. }
  494. }
  495. return true;
  496. }
  497. bool FracZToNchw(const FormatArgs &args, void *result) {
  498. MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
  499. MS_EXCEPTION_IF_NULL(result);
  500. if (args.host_shape.size() != kNchwDims) {
  501. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  502. return false;
  503. }
  504. size_t size = TypeIdSize(args.src_data_type);
  505. if (size < 1) {
  506. MS_LOG(ERROR) << "Illegal dtype.";
  507. return false;
  508. }
  509. size_t total_size = ShapeSize(args.device_shape) * size;
  510. if (total_size != args.device_size) {
  511. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  512. return false;
  513. }
  514. auto n0 = args.device_shape.at(1);
  515. auto ni = args.device_shape.at(2);
  516. auto c0 = args.device_shape.at(3);
  517. auto n = args.host_shape[0];
  518. auto c = args.host_shape[1];
  519. auto h = args.host_shape[2];
  520. auto w = args.host_shape[3];
  521. size_t nc = ni * n0;
  522. size_t ncc0 = nc * c0;
  523. size_t wncc0 = w * ncc0;
  524. size_t hwncc0 = h * wncc0;
  525. size_t hw = h * w;
  526. size_t chw = c * hw;
  527. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  528. size_t n_head_addr = n_idx * chw;
  529. for (size_t c_idx = 0; c_idx < c; c_idx++) {
  530. size_t c_head_addr = n_head_addr + c_idx * hw;
  531. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  532. size_t h_head_addr = c_head_addr + h_idx * w;
  533. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  534. size_t dst_idx = h_head_addr + w_idx;
  535. size_t c1_idx = c_idx / c0;
  536. size_t c0_idx = c_idx % c0;
  537. size_t nc_idx = n_idx;
  538. size_t src_idx = c1_idx * hwncc0 + h_idx * wncc0 + w_idx * ncc0 + nc_idx * c0 + c0_idx;
  539. auto src_offset = src_idx * size;
  540. auto dst_offset = dst_idx * size;
  541. auto protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  542. ? total_size - dst_offset
  543. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  544. auto ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
  545. static_cast<uint8_t const *>(args.data) + src_offset, size);
  546. if (ret != EOK) {
  547. MS_LOG(ERROR) << "Failed to operate the dst memory error-code " << ret;
  548. return false;
  549. }
  550. }
  551. }
  552. }
  553. }
  554. return true;
  555. }
  556. bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
  557. MS_EXCEPTION_IF_NULL(hw_shape);
  558. if (host_shape.empty()) {
  559. MS_LOG(ERROR) << "Size of vector is 0.";
  560. return false;
  561. }
  562. switch (host_shape.size()) {
  563. case 1:
  564. hw_shape->push_back(1);
  565. hw_shape->push_back(1);
  566. hw_shape->push_back(host_shape[0]);
  567. return true;
  568. default:
  569. auto size = host_shape.size();
  570. if (size < 2) {
  571. MS_LOG(ERROR) << "Illegal size.";
  572. return false;
  573. }
  574. size_t times = 1;
  575. for (size_t i = 0; i != size - 2; i++) {
  576. times *= host_shape[i];
  577. }
  578. hw_shape->push_back(times);
  579. hw_shape->push_back(host_shape[size - 2]);
  580. hw_shape->push_back(host_shape[size - 1]);
  581. return true;
  582. }
  583. }
  584. bool NchwToFracNz(const FormatArgs &args, void *result) {
  585. MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  586. MS_EXCEPTION_IF_NULL(result);
  587. std::vector<size_t> hw_shape;
  588. if (!TransShapeToNz(args.host_shape, &hw_shape)) {
  589. MS_LOG(ERROR) << "Trans shape failed..";
  590. return false;
  591. }
  592. if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
  593. MS_LOG(ERROR) << "Invalid shape size.";
  594. return false;
  595. }
  596. auto size = TypeIdSize(args.src_data_type);
  597. if (size < 1) {
  598. MS_LOG(ERROR) << "Illegal dtype";
  599. return false;
  600. }
  601. auto dst_size = ShapeSize(args.device_shape) * size;
  602. if (dst_size != args.device_size) {
  603. MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
  604. return false;
  605. }
  606. auto times = hw_shape.at(0);
  607. auto h = hw_shape.at(1);
  608. auto w = hw_shape.at(2);
  609. auto hw = h * w;
  610. auto shape_size = args.device_shape.size();
  611. auto w1 = args.device_shape[shape_size - 4];
  612. auto h1 = args.device_shape[shape_size - 3];
  613. auto h0 = args.device_shape[shape_size - 2];
  614. auto w0 = args.device_shape[shape_size - 1];
  615. auto h1h0w0 = h1 * h0 * w0;
  616. auto w1h1h0w0 = w1 * h1h0w0;
  617. auto num_w1 = w / w0;
  618. for (size_t times_idx = 0; times_idx < times; times_idx++) {
  619. auto times_head = times_idx * w1h1h0w0;
  620. auto src_times_head = times_idx * hw;
  621. for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
  622. auto h1h0_head = times_head + h1h0_idx * w0;
  623. auto src_h_head = src_times_head + h1h0_idx * w;
  624. for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
  625. size_t dst_offset = (h1h0_head + w1_idx * h1h0w0) * size;
  626. size_t src_offset = (src_h_head + w1_idx * w0) * size;
  627. auto protected_size = dst_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  628. ? dst_size - dst_offset
  629. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  630. auto cp_ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
  631. static_cast<uint8_t const *>(args.data) + src_offset, size * w0);
  632. if (cp_ret != EOK) {
  633. MS_LOG(ERROR) << "Failed to operate the dst memory, error-code " << cp_ret;
  634. return false;
  635. }
  636. }
  637. auto w1_head = num_w1 * w0;
  638. for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
  639. auto src_w_idx = w1_head + w0_idx;
  640. size_t dst_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size;
  641. size_t src_offset = (src_h_head + src_w_idx) * size;
  642. auto protected_size = dst_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  643. ? dst_size - dst_offset
  644. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  645. auto cp_ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
  646. static_cast<uint8_t const *>(args.data) + src_offset, size);
  647. if (cp_ret != EOK) {
  648. MS_LOG(ERROR) << "Failed to operate the dst memory error-code " << cp_ret;
  649. return false;
  650. }
  651. }
  652. }
  653. }
  654. return true;
  655. }
  656. bool FracNzToNchw(const FormatArgs &args, void *result) {
  657. MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
  658. MS_EXCEPTION_IF_NULL(result);
  659. std::vector<size_t> hw_shape;
  660. if (!TransShapeToNz(args.host_shape, &hw_shape)) {
  661. MS_LOG(ERROR) << "Trans shape failed..";
  662. return false;
  663. }
  664. if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
  665. MS_LOG(ERROR) << "Invalid shape size.";
  666. return false;
  667. }
  668. auto size = TypeIdSize(args.src_data_type);
  669. if (size < 1) {
  670. MS_LOG(ERROR) << "Illegal dtype";
  671. return false;
  672. }
  673. auto dst_size = ShapeSize(args.device_shape) * size;
  674. if (dst_size != args.device_size) {
  675. MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
  676. return false;
  677. }
  678. auto times = hw_shape.at(0);
  679. auto h = hw_shape.at(1);
  680. auto w = hw_shape.at(2);
  681. auto hw = h * w;
  682. auto shape_size = args.device_shape.size();
  683. auto w1 = args.device_shape[shape_size - 4];
  684. auto h1 = args.device_shape[shape_size - 3];
  685. auto h0 = args.device_shape[shape_size - 2];
  686. auto w0 = args.device_shape[shape_size - 1];
  687. auto h1h0w0 = h1 * h0 * w0;
  688. auto w1h1h0w0 = w1 * h1h0w0;
  689. auto num_w1 = w / w0;
  690. for (size_t times_idx = 0; times_idx < times; times_idx++) {
  691. auto times_head = times_idx * w1h1h0w0;
  692. auto src_times_head = times_idx * hw;
  693. for (size_t h1h0_idx = 0; h1h0_idx < h; h1h0_idx++) {
  694. auto h1h0_head = times_head + h1h0_idx * w0;
  695. auto src_h_head = src_times_head + h1h0_idx * w;
  696. for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
  697. size_t src_offset = (h1h0_head + w1_idx * h1h0w0) * size;
  698. size_t dst_offset = (src_h_head + w1_idx * w0) * size;
  699. auto protected_size = dst_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  700. ? dst_size - dst_offset
  701. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  702. auto cp_ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
  703. static_cast<uint8_t const *>(args.data) + src_offset, size * w0);
  704. if (cp_ret != EOK) {
  705. MS_LOG(ERROR) << "Failed to operate the dst memory, error-code " << cp_ret;
  706. return false;
  707. }
  708. }
  709. auto w1_head = num_w1 * w0;
  710. for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
  711. auto src_w_idx = w1_head + w0_idx;
  712. size_t src_offset = (h1h0_head + num_w1 * h1h0w0 + w0_idx) * size;
  713. size_t dst_offset = (src_h_head + src_w_idx) * size;
  714. auto protected_size = dst_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  715. ? dst_size - dst_offset
  716. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  717. auto cp_ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
  718. static_cast<uint8_t const *>(args.data) + src_offset, size);
  719. if (cp_ret != EOK) {
  720. MS_LOG(ERROR) << "Failed to operate the dst memory error-code " << cp_ret;
  721. return false;
  722. }
  723. }
  724. }
  725. }
  726. return true;
  727. }
  728. bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
  729. MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
  730. MS_EXCEPTION_IF_NULL(result);
  731. if (args.host_shape.size() != kNchwDims) {
  732. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  733. return false;
  734. }
  735. size_t size = TypeIdSize(args.src_data_type);
  736. if (size < 1) {
  737. MS_LOG(ERROR) << "Illegal dtype.";
  738. return false;
  739. }
  740. auto total_size = ShapeSize(args.device_shape) * size;
  741. if (total_size != args.device_size) {
  742. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  743. return false;
  744. }
  745. auto n = args.host_shape[0];
  746. auto c = args.host_shape[1];
  747. auto h = args.host_shape[2];
  748. auto w = args.host_shape[3];
  749. size_t c0 = CubeSizeByType(args.src_data_type);
  750. if (c0 < 1) {
  751. MS_LOG(ERROR) << "Illegal dtype.";
  752. return false;
  753. }
  754. size_t c1 = Ceil(c, c0);
  755. size_t hw = h * w;
  756. size_t chw = c * hw;
  757. size_t c1hwc0 = c1 * hw * c0;
  758. size_t wc0 = w * c0;
  759. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  760. size_t n_head_addr = n_idx * c1hwc0;
  761. for (size_t c1_idx = 0; c1_idx < c1; c1_idx++) {
  762. size_t c1_head_addr = n_head_addr + c1_idx * hw * c0;
  763. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  764. size_t h_head_addr = c1_head_addr + h_idx * wc0;
  765. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  766. size_t w_head_addr = h_head_addr + w_idx * c0;
  767. for (size_t c0_idx = 0; c0_idx < c0; c0_idx++) {
  768. size_t dst_index = c0_idx + w_head_addr;
  769. size_t dst_offset = dst_index * size;
  770. auto protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  771. ? total_size - dst_offset
  772. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  773. size_t c_idx = c0_idx + c1_idx * c0;
  774. size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
  775. auto src_offset = src_idx * size;
  776. if (c_idx < c) {
  777. auto ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
  778. static_cast<uint8_t const *>(args.data) + src_offset, size);
  779. if (ret != EOK) {
  780. MS_LOG(ERROR) << "Failed to operate the dst memory error-code " << ret;
  781. return false;
  782. }
  783. } else {
  784. auto ret = memset_s(static_cast<uint8_t *>(result) + dst_offset, protected_size, 0, size);
  785. if (ret != EOK) {
  786. MS_LOG(ERROR) << "Failed to operate the dst memory error-code " << ret;
  787. return false;
  788. }
  789. }
  790. }
  791. }
  792. }
  793. }
  794. }
  795. return true;
  796. }
  797. bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
  798. MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
  799. MS_EXCEPTION_IF_NULL(result);
  800. if (args.host_shape.size() != kNchwDims) {
  801. MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
  802. return false;
  803. }
  804. size_t size = TypeIdSize(args.src_data_type);
  805. if (size < 1) {
  806. MS_LOG(ERROR) << "Illegal dtype.";
  807. return false;
  808. }
  809. size_t total_size = ShapeSize(args.device_shape) * size;
  810. if (total_size != args.device_size) {
  811. MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
  812. return false;
  813. }
  814. auto n = args.host_shape[0];
  815. auto c = args.host_shape[1];
  816. auto h = args.host_shape[2];
  817. auto w = args.host_shape[3];
  818. auto c1 = args.device_shape[1];
  819. auto c0 = args.device_shape[4];
  820. size_t hw = h * w;
  821. size_t chw = c * hw;
  822. size_t wc0 = w * c0;
  823. size_t hwc0 = h * wc0;
  824. size_t c1hwc0 = c1 * hwc0;
  825. for (size_t n_idx = 0; n_idx < n; n_idx++) {
  826. size_t n_head_addr = n_idx * chw;
  827. for (size_t c_idx = 0; c_idx < c; c_idx++) {
  828. size_t c_head_addr = n_head_addr + c_idx * hw;
  829. for (size_t h_idx = 0; h_idx < h; h_idx++) {
  830. size_t h_head_addr = c_head_addr + h_idx * w;
  831. for (size_t w_idx = 0; w_idx < w; w_idx++) {
  832. size_t dst_idx = h_head_addr + w_idx;
  833. size_t c1_idx = c_idx / c0;
  834. size_t c0_idx = c_idx % c0;
  835. size_t src_idx = n_idx * c1hwc0 + c1_idx * hwc0 + h_idx * wc0 + w_idx * c0 + c0_idx;
  836. auto src_offset = src_idx * size;
  837. auto dst_offset = dst_idx * size;
  838. auto protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  839. ? total_size - dst_offset
  840. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  841. auto ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
  842. static_cast<uint8_t const *>(args.data) + src_offset, size);
  843. if (ret != EOK) {
  844. MS_LOG(ERROR) << "Failed to operate the dst memory error-code " << ret;
  845. return false;
  846. }
  847. }
  848. }
  849. }
  850. }
  851. return true;
  852. }
  853. bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
  854. // trans nchw to c1hwncoc0
  855. MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0.";
  856. MS_EXCEPTION_IF_NULL(result);
  857. size_t size = 0;
  858. size_t total_size = 0;
  859. if (!CheckArgs(args, &size, &total_size)) {
  860. MS_LOG(ERROR) << "Check args failed.";
  861. return false;
  862. }
  863. auto n = args.host_shape[0];
  864. auto c = args.host_shape[1];
  865. auto h = args.host_shape[2];
  866. auto w = args.host_shape[3];
  867. auto c1 = args.device_shape[0];
  868. auto co = args.device_shape[4];
  869. auto c0 = args.device_shape[5];
  870. for (size_t c1_i = 0; c1_i < c1; c1_i++) {
  871. for (size_t h_i = 0; h_i < h; h_i++) {
  872. for (size_t w_i = 0; w_i < w; w_i++) {
  873. for (size_t n_i = 0; n_i < n; n_i++) {
  874. for (size_t co_i = 0; co_i < co; co_i++) {
  875. for (size_t c0_i = 0; c0_i < c0; c0_i++) {
  876. size_t dst_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 +
  877. n_i * co * c0 + co_i * c0 + c0_i) *
  878. size;
  879. size_t protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  880. ? total_size - dst_offset
  881. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  882. size_t c_i = c0_i + c1_i * c0;
  883. size_t src_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size;
  884. errno_t ret;
  885. if (c_i < c && c0_i == co_i) {
  886. ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
  887. static_cast<uint8_t const *>(args.data) + src_offset, size);
  888. } else {
  889. ret = memset_s(static_cast<uint8_t *>(result) + dst_offset, protected_size, 0, size);
  890. }
  891. if (ret != EOK) {
  892. MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret;
  893. return false;
  894. }
  895. }
  896. }
  897. }
  898. }
  899. }
  900. }
  901. return true;
  902. }
  903. bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
  904. // trans c1hwncoc0 to nchw
  905. MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw";
  906. MS_EXCEPTION_IF_NULL(result);
  907. size_t size = 0;
  908. size_t total_size = 0;
  909. if (!CheckArgs(args, &size, &total_size)) {
  910. MS_LOG(ERROR) << "Check args failed.";
  911. return false;
  912. }
  913. auto n = args.host_shape[0];
  914. auto c = args.host_shape[1];
  915. auto h = args.host_shape[2];
  916. auto w = args.host_shape[3];
  917. auto co = args.device_shape[4];
  918. auto c0 = args.device_shape[5];
  919. for (size_t n_i = 0; n_i < n; n_i++) {
  920. for (size_t c_i = 0; c_i < c; c_i++) {
  921. for (size_t h_i = 0; h_i < h; h_i++) {
  922. for (size_t w_i = 0; w_i < w; w_i++) {
  923. size_t dst_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size;
  924. size_t c1_i = c_i / kCubeSize;
  925. size_t c0_i = c_i % kCubeSize;
  926. size_t co_i = c0_i;
  927. size_t src_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 +
  928. co_i * c0 + c0_i) *
  929. size;
  930. size_t protected_size = total_size - dst_offset < static_cast<size_t>(SECUREC_MEM_MAX_LEN)
  931. ? total_size - dst_offset
  932. : static_cast<size_t>(SECUREC_MEM_MAX_LEN);
  933. auto ret = memcpy_s(static_cast<uint8_t *>(result) + dst_offset, protected_size,
  934. static_cast<uint8_t const *>(args.data) + src_offset, size);
  935. if (ret != EOK) {
  936. MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret;
  937. return false;
  938. }
  939. }
  940. }
  941. }
  942. }
  943. return true;
  944. }
  945. } // namespace trans
  946. } // namespace mindspore