You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "transform/graph_ir/util.h"
  17. #include <utility>
  18. #include <map>
  19. #include "securec/include/securec.h"
  20. #include "utils/convert_utils.h"
  21. #include "utils/utils.h"
  22. namespace mindspore {
  23. namespace transform {
  24. using std::make_shared;
  25. using std::shared_ptr;
  26. using std::string;
  27. using std::vector;
  28. const size_t kErrorSize = 0;
  29. vector<int64_t> TransformUtil::ConvertIntToList(int64_t data, int size) {
  30. vector<int64_t> list{};
  31. if (size <= 0) {
  32. MS_LOG(WARNING) << "size <= 0";
  33. return list;
  34. }
  35. for (int i = 0; i < size; ++i) {
  36. list.push_back(data);
  37. }
  38. return list;
  39. }
  40. static std::map<MeDataType, GeDataType> datatype_trans_map = {
  41. {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
  42. {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE}, {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
  43. {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16}, {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
  44. {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64}, {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8},
  45. {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
  46. {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}};
  47. GeDataType TransformUtil::ConvertDataType(const MeDataType &type) {
  48. MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type";
  49. if (datatype_trans_map.find(type) != datatype_trans_map.end()) {
  50. return datatype_trans_map[type];
  51. } else {
  52. return GeDataType::DT_UNDEFINED;
  53. }
  54. }
  55. static std::map<MeDataType, size_t> datatype_size_map = {
  56. {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)}, // 1/2 of float
  57. {MeDataType::kNumberTypeFloat64, sizeof(double)}, {MeDataType::kNumberTypeInt8, sizeof(int8_t)},
  58. {MeDataType::kNumberTypeInt16, sizeof(int16_t)}, {MeDataType::kNumberTypeInt32, sizeof(int32_t)},
  59. {MeDataType::kNumberTypeInt64, sizeof(int64_t)}, {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)},
  60. {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)},
  61. {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}};
  62. size_t TransformUtil::GetDataTypeSize(const MeDataType &type) {
  63. if (datatype_size_map.find(type) != datatype_size_map.end()) {
  64. return datatype_size_map[type];
  65. } else {
  66. MS_LOG(ERROR) << "Illegal tensor data type!";
  67. return kErrorSize;
  68. }
  69. }
  70. GeFormat TransformUtil::ConvertFormat(const string &format) {
  71. if (format == kOpFormat_NCHW) {
  72. return GeFormat::FORMAT_NCHW;
  73. } else if (format == kOpFormat_NC1HWC0) {
  74. return GeFormat::FORMAT_NC1HWC0;
  75. } else if (format == kOpFormat_NHWC) {
  76. return GeFormat::FORMAT_NHWC;
  77. } else if (format == kOpFormat_HWCN) {
  78. return GeFormat::FORMAT_HWCN;
  79. } else {
  80. return GeFormat::FORMAT_ND;
  81. }
  82. }
  83. static int64_t IntegerCastFunc(size_t temp) { return static_cast<int64_t>(temp); }
  84. std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const std::vector<int> &me_shape,
  85. const MeDataType &me_type, const std::string &format) {
  86. // convert me shape to ge shape
  87. std::vector<int64_t> ge_shape;
  88. if (me_shape.size() == 1) {
  89. ge_shape.push_back(static_cast<int64_t>(me_shape[0]));
  90. } else {
  91. ge_shape.resize(me_shape.size());
  92. (void)std::transform(me_shape.begin(), me_shape.end(), ge_shape.begin(), IntegerCastFunc);
  93. }
  94. GeShape shape(ge_shape);
  95. if (shape.GetDimNum() == 0) {
  96. MS_LOG(INFO) << "The dims size of Ge tensor is zero";
  97. }
  98. // convert me format to ge format
  99. GeFormat ge_format = ConvertFormat(format);
  100. if (ge_format == GeFormat::FORMAT_ND) {
  101. MS_LOG(ERROR) << "undefined data format : " << static_cast<int>(ge_format);
  102. return nullptr;
  103. }
  104. // convert me datatype to ge datatype
  105. GeDataType data_type = ConvertDataType(me_type);
  106. if (data_type == GeDataType::DT_UNDEFINED) {
  107. MS_LOG(ERROR) << "undefined data type :" << me_type;
  108. return nullptr;
  109. }
  110. auto desc = std::make_shared<GeTensorDesc>(shape, ge_format, data_type);
  111. if (desc == nullptr) {
  112. MS_LOG(ERROR) << "Create GeTensorDesc failed!";
  113. return nullptr;
  114. }
  115. MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size();
  116. desc->SetRealDimCnt(SizeToInt(me_shape.size()));
  117. return desc;
  118. }
  119. // if failed, return empty vector.
  120. std::vector<GeTensorPtr> TransformUtil::ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
  121. const std::string &format) {
  122. std::vector<GeTensorPtr> ge_tensors;
  123. for (size_t index = 0; index < me_tensors.size(); index++) {
  124. MS_EXCEPTION_IF_NULL(me_tensors[index]);
  125. MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize();
  126. auto shape = me_tensors[index]->shape();
  127. std::string shape_str;
  128. for (size_t i = 0; i < shape.size(); i++) {
  129. shape_str += std::to_string(shape[i]);
  130. shape_str += " ";
  131. }
  132. MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}";
  133. MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type();
  134. auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format);
  135. if (ge_tensor_ptr != nullptr) {
  136. ge_tensors.emplace_back(ge_tensor_ptr);
  137. } else {
  138. MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!";
  139. ge_tensors.clear();
  140. return ge_tensors;
  141. }
  142. }
  143. return ge_tensors;
  144. }
  145. GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) {
  146. // get tensor data type size
  147. MS_EXCEPTION_IF_NULL(tensor);
  148. size_t type_size = GetDataTypeSize(tensor->data_type());
  149. if (type_size == kErrorSize) {
  150. MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
  151. return nullptr;
  152. }
  153. size_t elements_num = IntToSize(tensor->ElementsNum());
  154. if (UINT_MAX / type_size < elements_num) {
  155. MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size
  156. << " overflowed UINT_MAX: " << UINT_MAX << ".";
  157. return nullptr;
  158. }
  159. // get tensor buff size
  160. size_t data_buff_size = elements_num * type_size;
  161. if (data_buff_size == 0) {
  162. MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
  163. }
  164. // create ge tensor
  165. auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
  166. if (desc == nullptr) {
  167. MS_LOG(ERROR) << "Failed to get Tensor Desc";
  168. return nullptr;
  169. }
  170. GeTensorPtr tensor_ptr = make_shared<GeTensor>(*desc, static_cast<uint8_t *>(tensor->data_c()), data_buff_size);
  171. if (tensor_ptr != nullptr) {
  172. MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!";
  173. }
  174. return tensor_ptr;
  175. }
  176. std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
  177. const std::vector<std::vector<int>> &request_dims) {
  178. std::vector<MeTensorPtr> outputs;
  179. for (size_t index = 0; index < ge_tensors.size(); index++) {
  180. MeTensorPtr me_tensor_ptr = nullptr;
  181. if (index < request_dims.size()) {
  182. me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]);
  183. } else {
  184. std::vector<int> empty_shape;
  185. me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape);
  186. }
  187. if (me_tensor_ptr != nullptr) {
  188. outputs.emplace_back(me_tensor_ptr);
  189. } else {
  190. MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
  191. return outputs;
  192. }
  193. }
  194. return outputs;
  195. }
  196. std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
  197. std::vector<MeTensorPtr> outputs;
  198. for (size_t index = 0; index < ge_tensors.size(); index++) {
  199. MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]);
  200. if (me_tensor_ptr != nullptr) {
  201. outputs.emplace_back(me_tensor_ptr);
  202. } else {
  203. MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
  204. return outputs;
  205. }
  206. }
  207. return outputs;
  208. }
  209. MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) {
  210. switch (type) {
  211. case GeDataType::DT_FLOAT16:
  212. return MeDataType::kNumberTypeFloat16;
  213. case GeDataType::DT_FLOAT:
  214. return MeDataType::kNumberTypeFloat32;
  215. case GeDataType::DT_DOUBLE:
  216. return MeDataType::kNumberTypeFloat64;
  217. case GeDataType::DT_INT64:
  218. return MeDataType::kNumberTypeInt64;
  219. case GeDataType::DT_INT32:
  220. return MeDataType::kNumberTypeInt32;
  221. case GeDataType::DT_INT16:
  222. return MeDataType::kNumberTypeInt16;
  223. case GeDataType::DT_INT8:
  224. return MeDataType::kNumberTypeInt8;
  225. case GeDataType::DT_BOOL:
  226. return MeDataType::kNumberTypeBool;
  227. case GeDataType::DT_UINT8:
  228. return MeDataType::kNumberTypeUInt8;
  229. case GeDataType::DT_UINT16:
  230. return MeDataType::kNumberTypeUInt16;
  231. case GeDataType::DT_UINT32:
  232. return MeDataType::kNumberTypeUInt32;
  233. case GeDataType::DT_UINT64:
  234. return MeDataType::kNumberTypeUInt64;
  235. case GeDataType::DT_UNDEFINED:
  236. case GeDataType::DT_DUAL_SUB_UINT8:
  237. case GeDataType::DT_DUAL_SUB_INT8:
  238. case GeDataType::DT_DUAL:
  239. return MeDataType::kTypeUnknown;
  240. default:
  241. return MeDataType::kTypeUnknown;
  242. }
  243. }
  244. namespace {
  245. bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector<int> &request_dims) {
  246. MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims());
  247. MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims);
  248. const int GE_DIMS = 4;
  249. std::vector<int64_t> ge_dims = ge_shape.GetDims();
  250. if (request_dims.size() > ge_dims.size()) {
  251. MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's";
  252. return false;
  253. }
  254. // convert NHWC to NCHW
  255. if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) &&
  256. (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) {
  257. MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
  258. return true;
  259. }
  260. std::string::size_type i = 0;
  261. for (; i < request_dims.size(); i++) {
  262. if (ge_dims[i] != request_dims[i]) {
  263. MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's";
  264. return false;
  265. }
  266. }
  267. for (; i < ge_dims.size(); i++) {
  268. if (ge_dims[i] != 1) {
  269. MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1";
  270. return false;
  271. }
  272. }
  273. MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
  274. return true;
  275. }
  276. } // namespace
  277. GeShape TransformUtil::ConvertMeShape(const std::vector<int> &me_dims) {
  278. std::vector<int64_t> ge_dims;
  279. (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims));
  280. return GeShape(ge_dims);
  281. }
  282. std::vector<int> TransformUtil::ConvertGeShape(const GeShape &ge_shape) {
  283. std::vector<int> me_dims;
  284. std::vector<int64_t> ge_dims = ge_shape.GetDims();
  285. (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims));
  286. return me_dims;
  287. }
  288. std::vector<int> TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector<int> &request_dims) {
  289. vector<int> ret;
  290. if (ge_shape.GetDimNum() == 0) {
  291. MS_LOG(DEBUG) << "GeTensor's shape is scalar";
  292. return ret;
  293. }
  294. if (IsGeShapeCompatible(ge_shape, request_dims) == true) {
  295. ret = request_dims;
  296. } else {
  297. MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape";
  298. ret = ConvertGeShape(ge_shape);
  299. }
  300. return ret;
  301. }
  302. MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector<int> &me_dims,
  303. const TypeId &me_type) {
  304. MeTensor me_tensor(me_type, me_dims);
  305. // Get the writable data pointer of the tensor and cast it to its data type
  306. auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c());
  307. size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes());
  308. MS_EXCEPTION_IF_NULL(me_data_ptr);
  309. MS_EXCEPTION_IF_NULL(ge_tensor);
  310. if (me_data_size < ge_tensor->GetSize()) {
  311. MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor ["
  312. << ge_tensor->GetSize() << " bytes]";
  313. return nullptr;
  314. }
  315. // Copy or use the writable data pointer of the ME tensor
  316. MS_EXCEPTION_IF_NULL(ge_tensor->GetData());
  317. if (ge_tensor->GetSize() == 0) {
  318. MS_LOG(ERROR) << "GE tensor data size is zero!";
  319. return nullptr;
  320. }
  321. // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB
  322. // which is the size limit of memcpy_s
  323. memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize());
  324. return make_shared<MeTensor>(me_tensor);
  325. }
  326. MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) {
  327. MS_EXCEPTION_IF_NULL(ge_tensor);
  328. GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
  329. vector<int> me_dims = ConvertGeShape(ge_shape);
  330. TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
  331. if (type_id == MeDataType::kTypeUnknown) {
  332. MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
  333. << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  334. return nullptr;
  335. }
  336. return GenerateMeTensor(ge_tensor, me_dims, type_id);
  337. }
  338. // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape
  339. MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector<int> &request_dims) {
  340. MS_EXCEPTION_IF_NULL(ge_tensor);
  341. GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
  342. vector<int> me_dims = ConvertGeShape(ge_shape, request_dims);
  343. MS_LOG(INFO) << "GE tensor type is " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  344. // Create a tensor with wanted data type and shape
  345. TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
  346. if (type_id == MeDataType::kTypeUnknown) {
  347. MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
  348. << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  349. return nullptr;
  350. }
  351. return GenerateMeTensor(ge_tensor, me_dims, type_id);
  352. }
  353. std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
  354. std::string ret;
  355. if (ge_tensor == nullptr) {
  356. MS_LOG(ERROR) << "Input ge tensor is nullptr";
  357. return ret;
  358. }
  359. MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  360. switch (ge_tensor->GetTensorDesc().GetDataType()) {
  361. case GeDataType::DT_UINT32:
  362. ret = PrintVector(MakeVector<uint32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  363. break;
  364. case GeDataType::DT_FLOAT:
  365. ret = PrintVector(MakeVector<float_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  366. break;
  367. case GeDataType::DT_INT32:
  368. ret = PrintVector(MakeVector<int32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  369. break;
  370. case GeDataType::DT_DOUBLE:
  371. ret = PrintVector(MakeVector<double_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  372. break;
  373. case GeDataType::DT_INT64:
  374. ret = PrintVector(MakeVector<int64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  375. break;
  376. case GeDataType::DT_UINT64:
  377. ret = PrintVector(MakeVector<uint64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  378. break;
  379. case GeDataType::DT_INT16:
  380. ret = PrintVector(MakeVector<int16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  381. break;
  382. case GeDataType::DT_UINT16:
  383. ret = PrintVector(MakeVector<uint16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  384. break;
  385. case GeDataType::DT_DUAL_SUB_INT8:
  386. case GeDataType::DT_INT8:
  387. ret = PrintVector(MakeVector<int8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  388. break;
  389. case GeDataType::DT_UINT8:
  390. case GeDataType::DT_DUAL_SUB_UINT8:
  391. ret = PrintVector(MakeVector<uint8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  392. break;
  393. case GeDataType::DT_FLOAT16:
  394. case GeDataType::DT_BOOL:
  395. case GeDataType::DT_UNDEFINED:
  396. case GeDataType::DT_DUAL:
  397. default:
  398. MS_LOG(ERROR) << "Unsupported to print type:" << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())
  399. << " ge tensor";
  400. break;
  401. }
  402. return ret;
  403. }
  404. } // namespace transform
  405. } // namespace mindspore