You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

util.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "transform/graph_ir/util.h"
  17. #include <utility>
  18. #include <map>
  19. #include "securec/include/securec.h"
  20. #include "utils/convert_utils.h"
  21. #include "utils/utils.h"
  22. namespace mindspore {
  23. namespace transform {
  24. using std::make_shared;
  25. using std::shared_ptr;
  26. using std::string;
  27. using std::vector;
  28. const size_t kErrorSize = 0;
  29. vector<int64_t> TransformUtil::ConvertIntToList(int64_t data, int size) {
  30. vector<int64_t> list{};
  31. if (size <= 0) {
  32. MS_LOG(WARNING) << "size <= 0";
  33. return list;
  34. }
  35. for (int i = 0; i < size; ++i) {
  36. list.push_back(data);
  37. }
  38. return list;
  39. }
  40. static std::map<MeDataType, GeDataType> datatype_trans_map = {
  41. {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
  42. {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE}, {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
  43. {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16}, {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
  44. {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64}, {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8},
  45. {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
  46. {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}};
  47. GeDataType TransformUtil::ConvertDataType(const MeDataType &type) {
  48. MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type";
  49. if (datatype_trans_map.find(type) != datatype_trans_map.end()) {
  50. return datatype_trans_map[type];
  51. } else {
  52. return GeDataType::DT_UNDEFINED;
  53. }
  54. }
  55. static std::map<MeDataType, size_t> datatype_size_map = {
  56. {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)}, // 1/2 of float
  57. {MeDataType::kNumberTypeFloat64, sizeof(double)}, {MeDataType::kNumberTypeInt8, sizeof(int8_t)},
  58. {MeDataType::kNumberTypeInt16, sizeof(int16_t)}, {MeDataType::kNumberTypeInt32, sizeof(int32_t)},
  59. {MeDataType::kNumberTypeInt64, sizeof(int64_t)}, {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)},
  60. {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)},
  61. {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}};
  62. size_t TransformUtil::GetDataTypeSize(const MeDataType &type) {
  63. if (datatype_size_map.find(type) != datatype_size_map.end()) {
  64. return datatype_size_map[type];
  65. } else {
  66. MS_LOG(ERROR) << "Illegal tensor data type!";
  67. return kErrorSize;
  68. }
  69. }
  70. GeFormat TransformUtil::ConvertFormat(const string &format) {
  71. if (format == kOpFormat_NCHW) {
  72. return GeFormat::FORMAT_NCHW;
  73. } else if (format == kOpFormat_NC1HWC0) {
  74. return GeFormat::FORMAT_NC1HWC0;
  75. } else if (format == kOpFormat_NHWC) {
  76. return GeFormat::FORMAT_NHWC;
  77. } else if (format == kOpFormat_HWCN) {
  78. return GeFormat::FORMAT_HWCN;
  79. } else if (format == kOpFormat_ND) {
  80. return GeFormat::FORMAT_ND;
  81. } else {
  82. MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
  83. return GeFormat::FORMAT_ND;
  84. }
  85. }
  86. static int64_t IntegerCastFunc(size_t temp) { return static_cast<int64_t>(temp); }
  87. std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const ShapeVector &me_shape, const MeDataType &me_type,
  88. const std::string &format) {
  89. // convert me shape to ge shape
  90. std::vector<int64_t> ge_shape;
  91. if (me_shape.size() == 1) {
  92. ge_shape.push_back(static_cast<int64_t>(me_shape[0]));
  93. } else {
  94. ge_shape.resize(me_shape.size());
  95. (void)std::transform(me_shape.begin(), me_shape.end(), ge_shape.begin(), IntegerCastFunc);
  96. }
  97. GeShape shape(ge_shape);
  98. if (shape.GetDimNum() == 0) {
  99. MS_LOG(INFO) << "The dims size of Ge tensor is zero";
  100. }
  101. // convert me format to ge format
  102. GeFormat ge_format = ConvertFormat(format);
  103. if (ge_format == GeFormat::FORMAT_ND) {
  104. MS_LOG(INFO) << "Set ND data format";
  105. }
  106. // convert me datatype to ge datatype
  107. GeDataType data_type = ConvertDataType(me_type);
  108. if (data_type == GeDataType::DT_UNDEFINED) {
  109. MS_LOG(ERROR) << "undefined data type :" << me_type;
  110. return nullptr;
  111. }
  112. auto desc = std::make_shared<GeTensorDesc>(shape, ge_format, data_type);
  113. if (desc == nullptr) {
  114. MS_LOG(ERROR) << "Create GeTensorDesc failed!";
  115. return nullptr;
  116. }
  117. MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size();
  118. desc->SetRealDimCnt(SizeToInt(me_shape.size()));
  119. return desc;
  120. }
  121. // if failed, return empty vector.
  122. std::vector<GeTensorPtr> TransformUtil::ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
  123. const std::string &format) {
  124. std::vector<GeTensorPtr> ge_tensors;
  125. for (size_t index = 0; index < me_tensors.size(); index++) {
  126. MS_EXCEPTION_IF_NULL(me_tensors[index]);
  127. MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize();
  128. auto shape = me_tensors[index]->shape();
  129. std::string shape_str;
  130. for (size_t i = 0; i < shape.size(); i++) {
  131. shape_str += std::to_string(shape[i]);
  132. shape_str += " ";
  133. }
  134. MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}";
  135. MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type();
  136. auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format);
  137. if (ge_tensor_ptr != nullptr) {
  138. ge_tensors.emplace_back(ge_tensor_ptr);
  139. } else {
  140. MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!";
  141. ge_tensors.clear();
  142. return ge_tensors;
  143. }
  144. }
  145. return ge_tensors;
  146. }
  147. GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) {
  148. // get tensor data type size
  149. MS_EXCEPTION_IF_NULL(tensor);
  150. size_t type_size = GetDataTypeSize(tensor->data_type());
  151. if (type_size == kErrorSize) {
  152. MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
  153. return nullptr;
  154. }
  155. size_t elements_num = IntToSize(tensor->ElementsNum());
  156. if (UINT_MAX / type_size < elements_num) {
  157. MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size
  158. << " overflowed UINT_MAX: " << UINT_MAX << ".";
  159. return nullptr;
  160. }
  161. // get tensor buff size
  162. size_t data_buff_size = elements_num * type_size;
  163. if (data_buff_size == 0) {
  164. MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
  165. }
  166. // create ge tensor
  167. auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
  168. if (desc == nullptr) {
  169. MS_LOG(ERROR) << "Failed to get Tensor Desc";
  170. return nullptr;
  171. }
  172. GeTensorPtr tensor_ptr = make_shared<GeTensor>(*desc, static_cast<uint8_t *>(tensor->data_c()), data_buff_size);
  173. if (tensor_ptr != nullptr) {
  174. MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!";
  175. }
  176. return tensor_ptr;
  177. }
  178. std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
  179. const std::vector<ShapeVector> &request_dims) {
  180. std::vector<MeTensorPtr> outputs;
  181. for (size_t index = 0; index < ge_tensors.size(); index++) {
  182. MeTensorPtr me_tensor_ptr = nullptr;
  183. if (index < request_dims.size()) {
  184. me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]);
  185. } else {
  186. ShapeVector empty_shape;
  187. me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape);
  188. }
  189. if (me_tensor_ptr != nullptr) {
  190. outputs.emplace_back(me_tensor_ptr);
  191. } else {
  192. MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
  193. return outputs;
  194. }
  195. }
  196. return outputs;
  197. }
  198. std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
  199. std::vector<MeTensorPtr> outputs;
  200. for (size_t index = 0; index < ge_tensors.size(); index++) {
  201. MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]);
  202. if (me_tensor_ptr != nullptr) {
  203. outputs.emplace_back(me_tensor_ptr);
  204. } else {
  205. MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
  206. return outputs;
  207. }
  208. }
  209. return outputs;
  210. }
  211. MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) {
  212. switch (type) {
  213. case GeDataType::DT_FLOAT16:
  214. return MeDataType::kNumberTypeFloat16;
  215. case GeDataType::DT_FLOAT:
  216. return MeDataType::kNumberTypeFloat32;
  217. case GeDataType::DT_DOUBLE:
  218. return MeDataType::kNumberTypeFloat64;
  219. case GeDataType::DT_INT64:
  220. return MeDataType::kNumberTypeInt64;
  221. case GeDataType::DT_INT32:
  222. return MeDataType::kNumberTypeInt32;
  223. case GeDataType::DT_INT16:
  224. return MeDataType::kNumberTypeInt16;
  225. case GeDataType::DT_INT8:
  226. return MeDataType::kNumberTypeInt8;
  227. case GeDataType::DT_BOOL:
  228. return MeDataType::kNumberTypeBool;
  229. case GeDataType::DT_UINT8:
  230. return MeDataType::kNumberTypeUInt8;
  231. case GeDataType::DT_UINT16:
  232. return MeDataType::kNumberTypeUInt16;
  233. case GeDataType::DT_UINT32:
  234. return MeDataType::kNumberTypeUInt32;
  235. case GeDataType::DT_UINT64:
  236. return MeDataType::kNumberTypeUInt64;
  237. case GeDataType::DT_UNDEFINED:
  238. case GeDataType::DT_DUAL_SUB_UINT8:
  239. case GeDataType::DT_DUAL_SUB_INT8:
  240. case GeDataType::DT_DUAL:
  241. return MeDataType::kTypeUnknown;
  242. default:
  243. return MeDataType::kTypeUnknown;
  244. }
  245. }
  246. namespace {
  247. bool IsGeShapeCompatible(const GeShape &ge_shape, const ShapeVector &request_dims) {
  248. MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims());
  249. MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims);
  250. const int GE_DIMS = 4;
  251. std::vector<int64_t> ge_dims = ge_shape.GetDims();
  252. if (request_dims.size() > ge_dims.size()) {
  253. MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's";
  254. return false;
  255. }
  256. // convert NHWC to NCHW
  257. if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) &&
  258. (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) {
  259. MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
  260. return true;
  261. }
  262. std::string::size_type i = 0;
  263. for (; i < request_dims.size(); i++) {
  264. if (ge_dims[i] != request_dims[i]) {
  265. MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's";
  266. return false;
  267. }
  268. }
  269. for (; i < ge_dims.size(); i++) {
  270. if (ge_dims[i] != 1) {
  271. MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1";
  272. return false;
  273. }
  274. }
  275. MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
  276. return true;
  277. }
  278. } // namespace
  279. GeShape TransformUtil::ConvertMeShape(const ShapeVector &me_dims) {
  280. std::vector<int64_t> ge_dims;
  281. (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims));
  282. return GeShape(ge_dims);
  283. }
  284. ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape) {
  285. ShapeVector me_dims;
  286. std::vector<int64_t> ge_dims = ge_shape.GetDims();
  287. (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims));
  288. return me_dims;
  289. }
  290. ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims) {
  291. vector<int64_t> ret;
  292. if (ge_shape.GetDimNum() == 0) {
  293. MS_LOG(DEBUG) << "GeTensor's shape is scalar";
  294. return ret;
  295. }
  296. if (IsGeShapeCompatible(ge_shape, request_dims) == true) {
  297. ret = request_dims;
  298. } else {
  299. MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape";
  300. ret = ConvertGeShape(ge_shape);
  301. }
  302. return ret;
  303. }
  304. MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims,
  305. const TypeId &me_type) {
  306. MeTensor me_tensor(me_type, me_dims);
  307. // Get the writable data pointer of the tensor and cast it to its data type
  308. auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c());
  309. size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes());
  310. MS_EXCEPTION_IF_NULL(me_data_ptr);
  311. MS_EXCEPTION_IF_NULL(ge_tensor);
  312. if (me_data_size < ge_tensor->GetSize()) {
  313. MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor ["
  314. << ge_tensor->GetSize() << " bytes]";
  315. return nullptr;
  316. }
  317. // Copy or use the writable data pointer of the ME tensor
  318. MS_EXCEPTION_IF_NULL(ge_tensor->GetData());
  319. if (ge_tensor->GetSize() == 0) {
  320. MS_LOG(ERROR) << "GE tensor data size is zero!";
  321. return nullptr;
  322. }
  323. // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB
  324. // which is the size limit of memcpy_s
  325. memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize());
  326. return make_shared<MeTensor>(me_tensor);
  327. }
  328. MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) {
  329. MS_EXCEPTION_IF_NULL(ge_tensor);
  330. GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
  331. vector<int64_t> me_dims = ConvertGeShape(ge_shape);
  332. TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
  333. if (type_id == MeDataType::kTypeUnknown) {
  334. MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
  335. << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  336. return nullptr;
  337. }
  338. return GenerateMeTensor(ge_tensor, me_dims, type_id);
  339. }
  340. // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape
  341. MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const ShapeVector &request_dims) {
  342. MS_EXCEPTION_IF_NULL(ge_tensor);
  343. GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
  344. vector<int64_t> me_dims = ConvertGeShape(ge_shape, request_dims);
  345. MS_LOG(INFO) << "GE tensor type is " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  346. // Create a tensor with wanted data type and shape
  347. TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
  348. if (type_id == MeDataType::kTypeUnknown) {
  349. MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
  350. << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  351. return nullptr;
  352. }
  353. return GenerateMeTensor(ge_tensor, me_dims, type_id);
  354. }
  355. std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
  356. std::string ret;
  357. if (ge_tensor == nullptr) {
  358. MS_LOG(ERROR) << "Input ge tensor is nullptr";
  359. return ret;
  360. }
  361. MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  362. switch (ge_tensor->GetTensorDesc().GetDataType()) {
  363. case GeDataType::DT_UINT32:
  364. ret = PrintVector(MakeVector<uint32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  365. break;
  366. case GeDataType::DT_FLOAT:
  367. ret = PrintVector(MakeVector<float_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  368. break;
  369. case GeDataType::DT_INT32:
  370. ret = PrintVector(MakeVector<int32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  371. break;
  372. case GeDataType::DT_DOUBLE:
  373. ret = PrintVector(MakeVector<double_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  374. break;
  375. case GeDataType::DT_INT64:
  376. ret = PrintVector(MakeVector<int64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  377. break;
  378. case GeDataType::DT_UINT64:
  379. ret = PrintVector(MakeVector<uint64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  380. break;
  381. case GeDataType::DT_INT16:
  382. ret = PrintVector(MakeVector<int16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  383. break;
  384. case GeDataType::DT_UINT16:
  385. ret = PrintVector(MakeVector<uint16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  386. break;
  387. case GeDataType::DT_DUAL_SUB_INT8:
  388. case GeDataType::DT_INT8:
  389. ret = PrintVector(MakeVector<int8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  390. break;
  391. case GeDataType::DT_UINT8:
  392. case GeDataType::DT_DUAL_SUB_UINT8:
  393. ret = PrintVector(MakeVector<uint8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  394. break;
  395. case GeDataType::DT_FLOAT16:
  396. case GeDataType::DT_BOOL:
  397. case GeDataType::DT_UNDEFINED:
  398. case GeDataType::DT_DUAL:
  399. default:
  400. MS_LOG(ERROR) << "Unsupported to print type:" << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())
  401. << " ge tensor";
  402. break;
  403. }
  404. return ret;
  405. }
  406. } // namespace transform
  407. } // namespace mindspore