You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

npy_header.cc 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "debug/data_dump/npy_header.h"
  17. #include <utility>
  18. #include <sstream>
  19. #include "utils/hash_map.h"
  20. #include "mindspore/core/ir/dtype.h"
  21. #include "mindspore/core/utils/log_adapter.h"
  22. #include "mindspore/core/utils/convert_utils_base.h"
  23. namespace mindspore {
  24. namespace {
  25. // npy file header start information
  26. const char kMagicPrefix[] = "\x93NUMPY";
  27. // magical length include kMagicPrefix length and version length
  28. const size_t kMagicLen = 6;
  29. const size_t kArrayAlign = 64;
  30. // first: header_length_type, second: encoding_type
  31. // header_length_type: 1 represents 2 bytes; 2 and 3 represents 4 bytes
  32. // encoding_type: 1 and 2 represents 'latin1'; 3 represents 'utf8'
  33. using version_type = std::pair<int, int>;
  34. // data type description
  35. // byteorder char: '<' is little endian; '>' is big endian; '|' is ignore(no change to byte order)
  36. // type char: 'b' represents bool; 'u' represents uint; 'i' represents int; 'f' represents float
  37. struct DtypeDescr {
  38. char byteorder;
  39. char type;
  40. size_t length;
  41. std::string str() const;
  42. };
  43. // npy file header description, includes data type description, fortran_order and array shape
  44. // fortran_order: true represents the array data Fortran-contiguous; false represents the array data C-contiguity
  45. struct NpyHeader {
  46. public:
  47. DtypeDescr dtype_descr;
  48. bool fortran_order;
  49. ShapeVector shape;
  50. std::string str() const;
  51. private:
  52. std::string fortran_order_to_str() const;
  53. std::string shape_to_str() const;
  54. };
  55. std::string DtypeDescr::str() const {
  56. std::ostringstream buffer;
  57. buffer << "\'" << byteorder << type << length << "\'";
  58. return buffer.str();
  59. }
  60. std::string NpyHeader::str() const {
  61. const std::string first_field = "'descr': ";
  62. const std::string second_field = "'fortran_order': ";
  63. const std::string third_field = "'shape': ";
  64. std::ostringstream buffer;
  65. buffer << "{" << first_field << dtype_descr.str() << ", " << second_field << fortran_order_to_str() << ", "
  66. << third_field << shape_to_str() << ", }";
  67. return buffer.str();
  68. }
  69. std::string NpyHeader::fortran_order_to_str() const { return fortran_order ? "True" : "False"; }
  70. std::string NpyHeader::shape_to_str() const {
  71. std::ostringstream buffer;
  72. buffer << "(";
  73. for (const auto i : shape) {
  74. buffer << std::to_string(i) << ",";
  75. }
  76. buffer << ")";
  77. return buffer.str();
  78. }
  79. // dtype description corresponding to tensor type
  80. const mindspore::HashMap<TypeId, DtypeDescr> type_desc_map = {
  81. {kNumberTypeBool, DtypeDescr{'|', 'b', 1}}, {kNumberTypeInt8, DtypeDescr{'|', 'i', 1}},
  82. {kNumberTypeInt16, DtypeDescr{'<', 'i', 2}}, {kNumberTypeInt32, DtypeDescr{'<', 'i', 4}},
  83. {kNumberTypeInt64, DtypeDescr{'<', 'i', 8}}, {kNumberTypeUInt8, DtypeDescr{'|', 'u', 1}},
  84. {kNumberTypeUInt16, DtypeDescr{'<', 'u', 2}}, {kNumberTypeUInt32, DtypeDescr{'<', 'u', 4}},
  85. {kNumberTypeUInt64, DtypeDescr{'<', 'u', 8}}, {kNumberTypeFloat16, DtypeDescr{'<', 'f', 2}},
  86. {kNumberTypeFloat32, DtypeDescr{'<', 'f', 4}}, {kNumberTypeFloat64, DtypeDescr{'<', 'f', 8}},
  87. };
  88. } // namespace
  89. void int_to_byte(size_t number, char *byte, size_t length) {
  90. const size_t byte_len = 8;
  91. const size_t mask = 0xff;
  92. for (size_t i = 0; i < length; i++) {
  93. byte[i] = (number >> (i * byte_len)) & mask;
  94. }
  95. }
  96. std::string GenerateNpyHeader(const ShapeVector &shape, TypeId type_id, bool fortran_order) {
  97. auto type_desc = type_desc_map.find(type_id);
  98. if (type_desc == type_desc_map.end()) {
  99. MS_LOG(INFO) << "Not support dump the " << TypeIdToType(type_id)->ToString() << " data to npy file.";
  100. return std::string();
  101. }
  102. NpyHeader npy_header{type_desc->second, fortran_order, shape};
  103. std::string header_str = npy_header.str();
  104. version_type version{1, 0};
  105. const size_t header_len = header_str.length();
  106. const size_t version_len = 2;
  107. const size_t max_len = 65535;
  108. size_t length_len = 2;
  109. size_t total_len = kMagicLen + version_len + length_len + header_len + 1;
  110. if (total_len > max_len) {
  111. version = {2, 0};
  112. length_len = 4;
  113. total_len = kMagicLen + version_len + length_len + header_len + 1;
  114. }
  115. const size_t pad_len = kArrayAlign - total_len % kArrayAlign;
  116. const size_t padding_header_len = header_len + pad_len + 1;
  117. const std::string padding(pad_len, ' ');
  118. const std::string end_line = "\n";
  119. char *length_byte = new char[length_len];
  120. int_to_byte(padding_header_len, length_byte, length_len);
  121. std::ostringstream out;
  122. (void)out.write(kMagicPrefix, SizeToLong(kMagicLen));
  123. (void)out.put(version.first);
  124. (void)out.put(version.second);
  125. (void)out.write(length_byte, SizeToLong(length_len));
  126. out << header_str << padding << end_line;
  127. delete[] length_byte;
  128. return out.str();
  129. }
  130. } // namespace mindspore