You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

util.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "transform/util.h"
  17. #include <utility>
  18. #include <sstream>
  19. #include <map>
  20. #include "securec/include/securec.h"
  21. #include "utils/convert_utils.h"
  22. #include "utils/utils.h"
  23. namespace mindspore {
  24. namespace transform {
  25. using std::make_shared;
  26. using std::shared_ptr;
  27. using std::string;
  28. using std::vector;
  // Sentinel element size reported by GetDataTypeSize for data types it does not know.
  constexpr size_t kErrorSize = 0;
  30. vector<int64_t> TransformUtil::ConvertIntToList(int64_t data, int size) {
  31. vector<int64_t> list{};
  32. if (size <= 0) {
  33. MS_LOG(WARNING) << "size <= 0";
  34. return list;
  35. }
  36. for (int i = 0; i < size; ++i) {
  37. list.push_back(data);
  38. }
  39. return list;
  40. }
  41. static std::map<MeDataType, GeDataType> datatype_trans_map = {
  42. {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
  43. {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE}, {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
  44. {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16}, {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
  45. {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64}, {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8},
  46. {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
  47. {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}};
  48. GeDataType TransformUtil::ConvertDataType(const MeDataType &type) {
  49. MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type";
  50. if (datatype_trans_map.find(type) != datatype_trans_map.end()) {
  51. return datatype_trans_map[type];
  52. } else {
  53. return GeDataType::DT_UNDEFINED;
  54. }
  55. }
  56. static std::map<MeDataType, size_t> datatype_size_map = {
  57. {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)}, // 1/2 of float
  58. {MeDataType::kNumberTypeFloat64, sizeof(double)}, {MeDataType::kNumberTypeInt8, sizeof(int8_t)},
  59. {MeDataType::kNumberTypeInt16, sizeof(int16_t)}, {MeDataType::kNumberTypeInt32, sizeof(int32_t)},
  60. {MeDataType::kNumberTypeInt64, sizeof(int64_t)}, {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)},
  61. {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)},
  62. {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}};
  63. size_t TransformUtil::GetDataTypeSize(const MeDataType &type) {
  64. if (datatype_size_map.find(type) != datatype_size_map.end()) {
  65. return datatype_size_map[type];
  66. } else {
  67. MS_LOG(ERROR) << "Illegal tensor data type!";
  68. return kErrorSize;
  69. }
  70. }
  71. GeFormat TransformUtil::ConvertFormat(const string &format) {
  72. if (format == kOpFormat_NCHW) {
  73. return GeFormat::FORMAT_NCHW;
  74. } else if (format == kOpFormat_NC1HWC0) {
  75. return GeFormat::FORMAT_NC1HWC0;
  76. } else if (format == kOpFormat_NHWC) {
  77. return GeFormat::FORMAT_NHWC;
  78. } else if (format == kOpFormat_HWCN) {
  79. return GeFormat::FORMAT_HWCN;
  80. } else {
  81. return GeFormat::FORMAT_ND;
  82. }
  83. }
  84. static int64_t IntegerCastFunc(size_t temp) { return static_cast<int64_t>(temp); }
  85. std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const std::vector<int> &me_shape,
  86. const MeDataType &me_type, const std::string &format) {
  87. // convert me shape to ge shape
  88. std::vector<int64_t> ge_shape;
  89. if (me_shape.size() == 1) {
  90. ge_shape.push_back(static_cast<int64_t>(me_shape[0]));
  91. } else {
  92. ge_shape.resize(me_shape.size());
  93. (void)std::transform(me_shape.begin(), me_shape.end(), ge_shape.begin(), IntegerCastFunc);
  94. }
  95. GeShape shape(ge_shape);
  96. if (shape.GetDimNum() == 0) {
  97. MS_LOG(INFO) << "The dims size of Ge tensor is zero";
  98. }
  99. // convert me format to ge format
  100. GeFormat ge_format = ConvertFormat(format);
  101. if (ge_format == GeFormat::FORMAT_ND) {
  102. MS_LOG(ERROR) << "undefined data format : " << static_cast<int>(ge_format);
  103. return nullptr;
  104. }
  105. // convert me datatype to ge datatype
  106. GeDataType data_type = ConvertDataType(me_type);
  107. if (data_type == GeDataType::DT_UNDEFINED) {
  108. MS_LOG(ERROR) << "undefined data type :" << me_type;
  109. return nullptr;
  110. }
  111. auto desc = std::make_shared<GeTensorDesc>(shape, ge_format, data_type);
  112. if (desc == nullptr) {
  113. MS_LOG(ERROR) << "Create GeTensorDesc failed!";
  114. return nullptr;
  115. }
  116. MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size();
  117. desc->SetRealDimCnt(SizeToInt(me_shape.size()));
  118. return desc;
  119. }
  120. // if failed, return empty vector.
  121. std::vector<GeTensorPtr> TransformUtil::ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
  122. const std::string &format) {
  123. std::vector<GeTensorPtr> ge_tensors;
  124. for (size_t index = 0; index < me_tensors.size(); index++) {
  125. MS_EXCEPTION_IF_NULL(me_tensors[index]);
  126. MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize();
  127. auto shape = me_tensors[index]->shape();
  128. std::string shape_str;
  129. for (size_t i = 0; i < shape.size(); i++) {
  130. shape_str += std::to_string(shape[i]);
  131. shape_str += " ";
  132. }
  133. MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}";
  134. MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type();
  135. auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format);
  136. if (ge_tensor_ptr != nullptr) {
  137. ge_tensors.emplace_back(ge_tensor_ptr);
  138. } else {
  139. MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!";
  140. ge_tensors.clear();
  141. return ge_tensors;
  142. }
  143. }
  144. return ge_tensors;
  145. }
  146. GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) {
  147. // get tensor data type size
  148. MS_EXCEPTION_IF_NULL(tensor);
  149. size_t type_size = GetDataTypeSize(tensor->data_type());
  150. if (type_size == kErrorSize) {
  151. MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
  152. return nullptr;
  153. }
  154. size_t elements_num = IntToSize(tensor->ElementsNum());
  155. if (UINT_MAX / type_size < elements_num) {
  156. MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size
  157. << " overflowed UINT_MAX: " << UINT_MAX << ".";
  158. return nullptr;
  159. }
  160. // get tensor buff size
  161. size_t data_buff_size = elements_num * type_size;
  162. if (data_buff_size == 0) {
  163. MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
  164. }
  165. // create ge tensor
  166. auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
  167. if (desc == nullptr) {
  168. MS_LOG(ERROR) << "Failed to get Tensor Desc";
  169. return nullptr;
  170. }
  171. GeTensorPtr tensor_ptr = make_shared<GeTensor>(*desc, static_cast<uint8_t *>(tensor->data_c()), data_buff_size);
  172. if (tensor_ptr != nullptr) {
  173. MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!";
  174. }
  175. return tensor_ptr;
  176. }
  177. std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
  178. const std::vector<std::vector<int>> &request_dims) {
  179. std::vector<MeTensorPtr> outputs;
  180. for (size_t index = 0; index < ge_tensors.size(); index++) {
  181. MeTensorPtr me_tensor_ptr = nullptr;
  182. if (index < request_dims.size()) {
  183. me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]);
  184. } else {
  185. std::vector<int> empty_shape;
  186. me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape);
  187. }
  188. if (me_tensor_ptr != nullptr) {
  189. outputs.emplace_back(me_tensor_ptr);
  190. } else {
  191. MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
  192. return outputs;
  193. }
  194. }
  195. return outputs;
  196. }
  197. std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
  198. std::vector<MeTensorPtr> outputs;
  199. for (size_t index = 0; index < ge_tensors.size(); index++) {
  200. MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]);
  201. if (me_tensor_ptr != nullptr) {
  202. outputs.emplace_back(me_tensor_ptr);
  203. } else {
  204. MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
  205. return outputs;
  206. }
  207. }
  208. return outputs;
  209. }
  210. MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) {
  211. switch (type) {
  212. case GeDataType::DT_FLOAT16:
  213. return MeDataType::kNumberTypeFloat16;
  214. case GeDataType::DT_FLOAT:
  215. return MeDataType::kNumberTypeFloat32;
  216. case GeDataType::DT_DOUBLE:
  217. return MeDataType::kNumberTypeFloat64;
  218. case GeDataType::DT_INT64:
  219. return MeDataType::kNumberTypeInt64;
  220. case GeDataType::DT_INT32:
  221. return MeDataType::kNumberTypeInt32;
  222. case GeDataType::DT_INT16:
  223. return MeDataType::kNumberTypeInt16;
  224. case GeDataType::DT_INT8:
  225. return MeDataType::kNumberTypeInt8;
  226. case GeDataType::DT_BOOL:
  227. return MeDataType::kNumberTypeBool;
  228. case GeDataType::DT_UINT8:
  229. return MeDataType::kNumberTypeUInt8;
  230. case GeDataType::DT_UINT16:
  231. return MeDataType::kNumberTypeUInt16;
  232. case GeDataType::DT_UINT32:
  233. return MeDataType::kNumberTypeUInt32;
  234. case GeDataType::DT_UINT64:
  235. return MeDataType::kNumberTypeUInt64;
  236. case GeDataType::DT_UNDEFINED:
  237. case GeDataType::DT_DUAL_SUB_UINT8:
  238. case GeDataType::DT_DUAL_SUB_INT8:
  239. case GeDataType::DT_DUAL:
  240. return MeDataType::kTypeUnknown;
  241. default:
  242. return MeDataType::kTypeUnknown;
  243. }
  244. }
  245. namespace {
  246. bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector<int> &request_dims) {
  247. MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims());
  248. MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims);
  249. const int GE_DIMS = 4;
  250. std::vector<int64_t> ge_dims = ge_shape.GetDims();
  251. if (request_dims.size() > ge_dims.size()) {
  252. MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's";
  253. return false;
  254. }
  255. // convert NHWC to NCHW
  256. if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) &&
  257. (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) {
  258. MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
  259. return true;
  260. }
  261. std::string::size_type i = 0;
  262. for (; i < request_dims.size(); i++) {
  263. if (ge_dims[i] != request_dims[i]) {
  264. MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's";
  265. return false;
  266. }
  267. }
  268. for (; i < ge_dims.size(); i++) {
  269. if (ge_dims[i] != 1) {
  270. MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1";
  271. return false;
  272. }
  273. }
  274. MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
  275. return true;
  276. }
  277. } // namespace
  278. GeShape TransformUtil::ConvertMeShape(const std::vector<int> &me_dims) {
  279. std::vector<int64_t> ge_dims;
  280. (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims));
  281. return GeShape(ge_dims);
  282. }
  283. std::vector<int> TransformUtil::ConvertGeShape(const GeShape &ge_shape) {
  284. std::vector<int> me_dims;
  285. std::vector<int64_t> ge_dims = ge_shape.GetDims();
  286. (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims));
  287. return me_dims;
  288. }
  289. std::vector<int> TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector<int> &request_dims) {
  290. vector<int> ret;
  291. if (ge_shape.GetDimNum() == 0) {
  292. MS_LOG(DEBUG) << "GeTensor's shape is scalar";
  293. return ret;
  294. }
  295. if (IsGeShapeCompatible(ge_shape, request_dims) == true) {
  296. ret = request_dims;
  297. } else {
  298. MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape";
  299. ret = ConvertGeShape(ge_shape);
  300. }
  301. return ret;
  302. }
  303. MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector<int> &me_dims,
  304. const TypeId &me_type) {
  305. MeTensor me_tensor(me_type, me_dims);
  306. // Get the writable data pointer of the tensor and cast it to its data type
  307. auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c());
  308. size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes());
  309. MS_EXCEPTION_IF_NULL(me_data_ptr);
  310. MS_EXCEPTION_IF_NULL(ge_tensor);
  311. if (me_data_size < ge_tensor->GetSize()) {
  312. MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor ["
  313. << ge_tensor->GetSize() << " bytes]";
  314. return nullptr;
  315. }
  316. // Copy or use the writable data pointer of the ME tensor
  317. MS_EXCEPTION_IF_NULL(ge_tensor->GetData());
  318. if (ge_tensor->GetSize() == 0) {
  319. MS_LOG(ERROR) << "GE tensor data size is zero!";
  320. return nullptr;
  321. }
  322. // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB
  323. // which is the size limit of memcpy_s
  324. memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize());
  325. return make_shared<MeTensor>(me_tensor);
  326. }
  327. MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) {
  328. MS_EXCEPTION_IF_NULL(ge_tensor);
  329. GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
  330. vector<int> me_dims = ConvertGeShape(ge_shape);
  331. TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
  332. if (type_id == MeDataType::kTypeUnknown) {
  333. MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
  334. << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  335. return nullptr;
  336. }
  337. return GenerateMeTensor(ge_tensor, me_dims, type_id);
  338. }
  339. // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape
  340. MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector<int> &request_dims) {
  341. MS_EXCEPTION_IF_NULL(ge_tensor);
  342. GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
  343. vector<int> me_dims = ConvertGeShape(ge_shape, request_dims);
  344. MS_LOG(INFO) << "GE tensor type is " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  345. // Create a tensor with wanted data type and shape
  346. TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
  347. if (type_id == MeDataType::kTypeUnknown) {
  348. MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
  349. << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  350. return nullptr;
  351. }
  352. return GenerateMeTensor(ge_tensor, me_dims, type_id);
  353. }
  354. std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
  355. std::string ret;
  356. if (ge_tensor == nullptr) {
  357. MS_LOG(ERROR) << "Input ge tensor is nullptr";
  358. return ret;
  359. }
  360. MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  361. switch (ge_tensor->GetTensorDesc().GetDataType()) {
  362. case GeDataType::DT_UINT32:
  363. ret = PrintVector(MakeVector<uint32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  364. break;
  365. case GeDataType::DT_FLOAT:
  366. ret = PrintVector(MakeVector<float_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  367. break;
  368. case GeDataType::DT_INT32:
  369. ret = PrintVector(MakeVector<int32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  370. break;
  371. case GeDataType::DT_DOUBLE:
  372. ret = PrintVector(MakeVector<double_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  373. break;
  374. case GeDataType::DT_INT64:
  375. ret = PrintVector(MakeVector<int64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  376. break;
  377. case GeDataType::DT_UINT64:
  378. ret = PrintVector(MakeVector<uint64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  379. break;
  380. case GeDataType::DT_INT16:
  381. ret = PrintVector(MakeVector<int16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  382. break;
  383. case GeDataType::DT_UINT16:
  384. ret = PrintVector(MakeVector<uint16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  385. break;
  386. case GeDataType::DT_DUAL_SUB_INT8:
  387. case GeDataType::DT_INT8:
  388. ret = PrintVector(MakeVector<int8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  389. break;
  390. case GeDataType::DT_UINT8:
  391. case GeDataType::DT_DUAL_SUB_UINT8:
  392. ret = PrintVector(MakeVector<uint8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  393. break;
  394. case GeDataType::DT_FLOAT16:
  395. case GeDataType::DT_BOOL:
  396. case GeDataType::DT_UNDEFINED:
  397. case GeDataType::DT_DUAL:
  398. default:
  399. MS_LOG(ERROR) << "Unsupported to print type:" << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())
  400. << " ge tensor";
  401. break;
  402. }
  403. return ret;
  404. }
  405. } // namespace transform
  406. } // namespace mindspore