You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

util.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "transform/graph_ir/util.h"

  #include <algorithm>
  #include <cstring>
  #include <iterator>
  #include <map>
  #include <memory>
  #include <string>
  #include <utility>
  #include <vector>

  #include "securec/include/securec.h"
  #include "utils/convert_utils.h"
  #include "utils/utils.h"
  22. namespace mindspore {
  23. namespace transform {
  24. using std::make_shared;
  25. using std::shared_ptr;
  26. using std::string;
  27. using std::vector;
  // Sentinel returned by GetDataTypeSize() for data types it does not know.
  // No real element type has size zero, so this value cannot be confused with a valid size.
  constexpr size_t kErrorSize{0};
  29. vector<int64_t> TransformUtil::ConvertIntToList(int64_t data, int size) {
  30. vector<int64_t> list{};
  31. if (size <= 0) {
  32. MS_LOG(WARNING) << "size <= 0";
  33. return list;
  34. }
  35. for (int i = 0; i < size; ++i) {
  36. list.push_back(data);
  37. }
  38. return list;
  39. }
  40. static std::map<MeDataType, GeDataType> datatype_trans_map = {
  41. {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
  42. {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE}, {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
  43. {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16}, {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
  44. {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64}, {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8},
  45. {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
  46. {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}};
  47. GeDataType TransformUtil::ConvertDataType(const MeDataType &type) {
  48. MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type";
  49. if (datatype_trans_map.find(type) != datatype_trans_map.end()) {
  50. return datatype_trans_map[type];
  51. } else {
  52. return GeDataType::DT_UNDEFINED;
  53. }
  54. }
  55. static std::map<MeDataType, size_t> datatype_size_map = {
  56. {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)}, // 1/2 of float
  57. {MeDataType::kNumberTypeFloat64, sizeof(double)}, {MeDataType::kNumberTypeInt8, sizeof(int8_t)},
  58. {MeDataType::kNumberTypeInt16, sizeof(int16_t)}, {MeDataType::kNumberTypeInt32, sizeof(int32_t)},
  59. {MeDataType::kNumberTypeInt64, sizeof(int64_t)}, {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)},
  60. {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)},
  61. {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}};
  62. size_t TransformUtil::GetDataTypeSize(const MeDataType &type) {
  63. if (datatype_size_map.find(type) != datatype_size_map.end()) {
  64. return datatype_size_map[type];
  65. } else {
  66. MS_LOG(ERROR) << "Illegal tensor data type!";
  67. return kErrorSize;
  68. }
  69. }
  70. GeFormat TransformUtil::ConvertFormat(const string &format) {
  71. if (format == kOpFormat_NCHW) {
  72. return GeFormat::FORMAT_NCHW;
  73. } else if (format == kOpFormat_NDHWC) {
  74. return GeFormat::FORMAT_NDHWC;
  75. } else if (format == kOpFormat_NCDHW) {
  76. return GeFormat::FORMAT_NCDHW;
  77. } else if (format == kOpFormat_DHWNC) {
  78. return GeFormat::FORMAT_DHWNC;
  79. } else if (format == kOpFormat_DHWCN) {
  80. return GeFormat::FORMAT_DHWCN;
  81. } else if (format == kOpFormat_NC1HWC0) {
  82. return GeFormat::FORMAT_NC1HWC0;
  83. } else if (format == kOpFormat_NHWC) {
  84. return GeFormat::FORMAT_NHWC;
  85. } else if (format == kOpFormat_HWCN) {
  86. return GeFormat::FORMAT_HWCN;
  87. } else if (format == kOpFormat_ND) {
  88. return GeFormat::FORMAT_ND;
  89. } else {
  90. MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
  91. return GeFormat::FORMAT_ND;
  92. }
  93. }
  // Unary cast used with std::transform to turn each shape dimension into a GE int64_t dimension.
  static int64_t IntegerCastFunc(size_t temp) {
    const int64_t as_signed = static_cast<int64_t>(temp);
    return as_signed;
  }
  95. std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const ShapeVector &me_shape, const MeDataType &me_type,
  96. const std::string &format) {
  97. // convert me shape to ge shape
  98. std::vector<int64_t> ge_shape;
  99. if (me_shape.size() == 1) {
  100. ge_shape.push_back(static_cast<int64_t>(me_shape[0]));
  101. } else {
  102. ge_shape.resize(me_shape.size());
  103. (void)std::transform(me_shape.begin(), me_shape.end(), ge_shape.begin(), IntegerCastFunc);
  104. }
  105. GeShape shape(ge_shape);
  106. if (shape.GetDimNum() == 0) {
  107. MS_LOG(INFO) << "The dims size of Ge tensor is zero";
  108. }
  109. // convert me format to ge format
  110. GeFormat ge_format = ConvertFormat(format);
  111. if (ge_format == GeFormat::FORMAT_ND) {
  112. MS_LOG(INFO) << "Set ND data format";
  113. }
  114. // convert me datatype to ge datatype
  115. GeDataType data_type = ConvertDataType(me_type);
  116. if (data_type == GeDataType::DT_UNDEFINED) {
  117. MS_LOG(ERROR) << "undefined data type :" << me_type;
  118. return nullptr;
  119. }
  120. auto desc = std::make_shared<GeTensorDesc>(shape, ge_format, data_type);
  121. if (desc == nullptr) {
  122. MS_LOG(ERROR) << "Create GeTensorDesc failed!";
  123. return nullptr;
  124. }
  125. MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size();
  126. desc->SetRealDimCnt(SizeToInt(me_shape.size()));
  127. return desc;
  128. }
  129. // if failed, return empty vector.
  130. std::vector<GeTensorPtr> TransformUtil::ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors,
  131. const std::string &format) {
  132. std::vector<GeTensorPtr> ge_tensors;
  133. for (size_t index = 0; index < me_tensors.size(); index++) {
  134. MS_EXCEPTION_IF_NULL(me_tensors[index]);
  135. MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize();
  136. auto shape = me_tensors[index]->shape();
  137. std::string shape_str;
  138. for (size_t i = 0; i < shape.size(); i++) {
  139. shape_str += std::to_string(shape[i]);
  140. shape_str += " ";
  141. }
  142. MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}";
  143. MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type();
  144. auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format);
  145. if (ge_tensor_ptr != nullptr) {
  146. ge_tensors.emplace_back(ge_tensor_ptr);
  147. } else {
  148. MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!";
  149. ge_tensors.clear();
  150. return ge_tensors;
  151. }
  152. }
  153. return ge_tensors;
  154. }
  155. GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) {
  156. // get tensor data type size
  157. MS_EXCEPTION_IF_NULL(tensor);
  158. size_t type_size = GetDataTypeSize(tensor->data_type());
  159. if (type_size == kErrorSize) {
  160. MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size;
  161. return nullptr;
  162. }
  163. size_t elements_num = IntToSize(tensor->ElementsNum());
  164. // get tensor buff size
  165. size_t data_buff_size = elements_num * type_size;
  166. if (data_buff_size == 0) {
  167. MS_LOG(INFO) << "The Me Tensor data buff size is 0.";
  168. }
  169. // create ge tensor
  170. auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
  171. if (desc == nullptr) {
  172. MS_LOG(ERROR) << "Failed to get Tensor Desc";
  173. return nullptr;
  174. }
  175. GeTensorPtr tensor_ptr = make_shared<GeTensor>(*desc, static_cast<uint8_t *>(tensor->data_c()), data_buff_size);
  176. if (tensor_ptr != nullptr) {
  177. MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!";
  178. }
  179. return tensor_ptr;
  180. }
  181. std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors,
  182. const std::vector<ShapeVector> &request_dims) {
  183. std::vector<MeTensorPtr> outputs;
  184. for (size_t index = 0; index < ge_tensors.size(); index++) {
  185. MeTensorPtr me_tensor_ptr = nullptr;
  186. if (index < request_dims.size()) {
  187. me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]);
  188. } else {
  189. ShapeVector empty_shape;
  190. me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape);
  191. }
  192. if (me_tensor_ptr != nullptr) {
  193. outputs.emplace_back(me_tensor_ptr);
  194. } else {
  195. MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
  196. return outputs;
  197. }
  198. }
  199. return outputs;
  200. }
  201. std::vector<MeTensorPtr> TransformUtil::ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
  202. std::vector<MeTensorPtr> outputs;
  203. for (size_t index = 0; index < ge_tensors.size(); index++) {
  204. MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]);
  205. if (me_tensor_ptr != nullptr) {
  206. outputs.emplace_back(me_tensor_ptr);
  207. } else {
  208. MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!";
  209. return outputs;
  210. }
  211. }
  212. return outputs;
  213. }
  214. MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) {
  215. switch (type) {
  216. case GeDataType::DT_FLOAT16:
  217. return MeDataType::kNumberTypeFloat16;
  218. case GeDataType::DT_FLOAT:
  219. return MeDataType::kNumberTypeFloat32;
  220. case GeDataType::DT_DOUBLE:
  221. return MeDataType::kNumberTypeFloat64;
  222. case GeDataType::DT_INT64:
  223. return MeDataType::kNumberTypeInt64;
  224. case GeDataType::DT_INT32:
  225. return MeDataType::kNumberTypeInt32;
  226. case GeDataType::DT_INT16:
  227. return MeDataType::kNumberTypeInt16;
  228. case GeDataType::DT_INT8:
  229. return MeDataType::kNumberTypeInt8;
  230. case GeDataType::DT_BOOL:
  231. return MeDataType::kNumberTypeBool;
  232. case GeDataType::DT_UINT8:
  233. return MeDataType::kNumberTypeUInt8;
  234. case GeDataType::DT_UINT16:
  235. return MeDataType::kNumberTypeUInt16;
  236. case GeDataType::DT_UINT32:
  237. return MeDataType::kNumberTypeUInt32;
  238. case GeDataType::DT_UINT64:
  239. return MeDataType::kNumberTypeUInt64;
  240. case GeDataType::DT_UNDEFINED:
  241. case GeDataType::DT_DUAL_SUB_UINT8:
  242. case GeDataType::DT_DUAL_SUB_INT8:
  243. case GeDataType::DT_DUAL:
  244. return MeDataType::kTypeUnknown;
  245. default:
  246. return MeDataType::kTypeUnknown;
  247. }
  248. }
  249. namespace {
  250. bool IsGeShapeCompatible(const GeShape &ge_shape, const ShapeVector &request_dims) {
  251. MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims());
  252. MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims);
  253. const int GE_DIMS = 4;
  254. std::vector<int64_t> ge_dims = ge_shape.GetDims();
  255. if (request_dims.size() > ge_dims.size()) {
  256. MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's";
  257. return false;
  258. }
  259. // convert NHWC to NCHW
  260. if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) &&
  261. (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) {
  262. MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
  263. return true;
  264. }
  265. std::string::size_type i = 0;
  266. for (; i < request_dims.size(); i++) {
  267. if (ge_dims[i] != request_dims[i]) {
  268. MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's";
  269. return false;
  270. }
  271. }
  272. for (; i < ge_dims.size(); i++) {
  273. if (ge_dims[i] != 1) {
  274. MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1";
  275. return false;
  276. }
  277. }
  278. MS_LOG(INFO) << "Ge tensor shape and request shape is compatible";
  279. return true;
  280. }
  281. } // namespace
  282. GeShape TransformUtil::ConvertMeShape(const ShapeVector &me_dims) {
  283. std::vector<int64_t> ge_dims;
  284. (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims));
  285. return GeShape(ge_dims);
  286. }
  287. ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape) {
  288. ShapeVector me_dims;
  289. std::vector<int64_t> ge_dims = ge_shape.GetDims();
  290. (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims));
  291. return me_dims;
  292. }
  293. ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims) {
  294. vector<int64_t> ret;
  295. if (ge_shape.GetDimNum() == 0) {
  296. MS_LOG(DEBUG) << "GeTensor's shape is scalar";
  297. return ret;
  298. }
  299. if (IsGeShapeCompatible(ge_shape, request_dims) == true) {
  300. ret = request_dims;
  301. } else {
  302. MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape";
  303. ret = ConvertGeShape(ge_shape);
  304. }
  305. return ret;
  306. }
  307. MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims,
  308. const TypeId &me_type) {
  309. MeTensor me_tensor(me_type, me_dims);
  310. // Get the writable data pointer of the tensor and cast it to its data type
  311. auto me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c());
  312. size_t me_data_size = static_cast<size_t>(me_tensor.data().nbytes());
  313. MS_EXCEPTION_IF_NULL(me_data_ptr);
  314. MS_EXCEPTION_IF_NULL(ge_tensor);
  315. if (me_data_size < ge_tensor->GetSize()) {
  316. MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor ["
  317. << ge_tensor->GetSize() << " bytes]";
  318. return nullptr;
  319. }
  320. // Copy or use the writable data pointer of the ME tensor
  321. MS_EXCEPTION_IF_NULL(ge_tensor->GetData());
  322. if (ge_tensor->GetSize() == 0) {
  323. MS_LOG(ERROR) << "GE tensor data size is zero!";
  324. return nullptr;
  325. }
  326. // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB
  327. // which is the size limit of memcpy_s
  328. memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize());
  329. return make_shared<MeTensor>(me_tensor);
  330. }
  331. MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) {
  332. MS_EXCEPTION_IF_NULL(ge_tensor);
  333. GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
  334. vector<int64_t> me_dims = ConvertGeShape(ge_shape);
  335. TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
  336. if (type_id == MeDataType::kTypeUnknown) {
  337. MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
  338. << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  339. return nullptr;
  340. }
  341. return GenerateMeTensor(ge_tensor, me_dims, type_id);
  342. }
  343. // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape
  344. MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const ShapeVector &request_dims) {
  345. MS_EXCEPTION_IF_NULL(ge_tensor);
  346. GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape();
  347. vector<int64_t> me_dims = ConvertGeShape(ge_shape, request_dims);
  348. MS_LOG(INFO) << "GE tensor type is " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  349. // Create a tensor with wanted data type and shape
  350. TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType());
  351. if (type_id == MeDataType::kTypeUnknown) {
  352. MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
  353. << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  354. return nullptr;
  355. }
  356. return GenerateMeTensor(ge_tensor, me_dims, type_id);
  357. }
  358. std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) {
  359. std::string ret;
  360. if (ge_tensor == nullptr) {
  361. MS_LOG(ERROR) << "Input ge tensor is nullptr";
  362. return ret;
  363. }
  364. MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType());
  365. switch (ge_tensor->GetTensorDesc().GetDataType()) {
  366. case GeDataType::DT_UINT32:
  367. ret = PrintVector(MakeVector<uint32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  368. break;
  369. case GeDataType::DT_FLOAT:
  370. ret = PrintVector(MakeVector<float_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  371. break;
  372. case GeDataType::DT_INT32:
  373. ret = PrintVector(MakeVector<int32_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  374. break;
  375. case GeDataType::DT_DOUBLE:
  376. ret = PrintVector(MakeVector<double_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  377. break;
  378. case GeDataType::DT_INT64:
  379. ret = PrintVector(MakeVector<int64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  380. break;
  381. case GeDataType::DT_UINT64:
  382. ret = PrintVector(MakeVector<uint64_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  383. break;
  384. case GeDataType::DT_INT16:
  385. ret = PrintVector(MakeVector<int16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  386. break;
  387. case GeDataType::DT_UINT16:
  388. ret = PrintVector(MakeVector<uint16_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  389. break;
  390. case GeDataType::DT_DUAL_SUB_INT8:
  391. case GeDataType::DT_INT8:
  392. ret = PrintVector(MakeVector<int8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  393. break;
  394. case GeDataType::DT_UINT8:
  395. case GeDataType::DT_DUAL_SUB_UINT8:
  396. ret = PrintVector(MakeVector<uint8_t>(ge_tensor->GetData(), ge_tensor->GetSize()));
  397. break;
  398. case GeDataType::DT_FLOAT16:
  399. case GeDataType::DT_BOOL:
  400. case GeDataType::DT_UNDEFINED:
  401. case GeDataType::DT_DUAL:
  402. default:
  403. MS_LOG(ERROR) << "Unsupported to print type:" << static_cast<int>(ge_tensor->GetTensorDesc().GetDataType())
  404. << " ge tensor";
  405. break;
  406. }
  407. return ret;
  408. }
  409. } // namespace transform
  410. } // namespace mindspore