You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

modelbin.cpp 5.7 kB

8 years ago
8 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "modelbin.h"
  15. #include "datareader.h"
  16. #include <string.h>
  17. namespace ncnn {
  18. ModelBin::~ModelBin()
  19. {
  20. }
  21. Mat ModelBin::load(int w, int h, int type) const
  22. {
  23. Mat m = load(w * h, type);
  24. if (m.empty())
  25. return m;
  26. return m.reshape(w, h);
  27. }
  28. Mat ModelBin::load(int w, int h, int c, int type) const
  29. {
  30. Mat m = load(w * h * c, type);
  31. if (m.empty())
  32. return m;
  33. return m.reshape(w, h, c);
  34. }
  35. ModelBinFromDataReader::ModelBinFromDataReader(const DataReader& _dr)
  36. : dr(_dr)
  37. {
  38. }
  39. Mat ModelBinFromDataReader::load(int w, int type) const
  40. {
  41. if (type == 0)
  42. {
  43. size_t nread;
  44. union
  45. {
  46. struct
  47. {
  48. unsigned char f0;
  49. unsigned char f1;
  50. unsigned char f2;
  51. unsigned char f3;
  52. };
  53. unsigned int tag;
  54. } flag_struct;
  55. nread = dr.read(&flag_struct, sizeof(flag_struct));
  56. if (nread != sizeof(flag_struct))
  57. {
  58. NCNN_LOGE("ModelBin read flag_struct failed %zd", nread);
  59. return Mat();
  60. }
  61. unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
  62. if (flag_struct.tag == 0x01306B47)
  63. {
  64. // half-precision data
  65. size_t align_data_size = alignSize(w * sizeof(unsigned short), 4);
  66. std::vector<unsigned short> float16_weights;
  67. float16_weights.resize(align_data_size);
  68. nread = dr.read(float16_weights.data(), align_data_size);
  69. if (nread != align_data_size)
  70. {
  71. NCNN_LOGE("ModelBin read float16_weights failed %zd", nread);
  72. return Mat();
  73. }
  74. return Mat::from_float16(float16_weights.data(), w);
  75. }
  76. else if (flag_struct.tag == 0x000D4B38)
  77. {
  78. // int8 data
  79. size_t align_data_size = alignSize(w, 4);
  80. std::vector<signed char> int8_weights;
  81. int8_weights.resize(align_data_size);
  82. nread = dr.read(int8_weights.data(), align_data_size);
  83. if (nread != align_data_size)
  84. {
  85. NCNN_LOGE("ModelBin read int8_weights failed %zd", nread);
  86. return Mat();
  87. }
  88. Mat m(w, (size_t)1u);
  89. if (m.empty())
  90. return m;
  91. memcpy(m.data, int8_weights.data(), w);
  92. return m;
  93. }
  94. else if (flag_struct.tag == 0x0002C056)
  95. {
  96. Mat m(w);
  97. if (m.empty())
  98. return m;
  99. // raw data with extra scaling
  100. nread = dr.read(m, w * sizeof(float));
  101. if (nread != w * sizeof(float))
  102. {
  103. NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
  104. return Mat();
  105. }
  106. return m;
  107. }
  108. Mat m(w);
  109. if (m.empty())
  110. return m;
  111. if (flag != 0)
  112. {
  113. // quantized data
  114. float quantization_value[256];
  115. nread = dr.read(quantization_value, 256 * sizeof(float));
  116. if (nread != 256 * sizeof(float))
  117. {
  118. NCNN_LOGE("ModelBin read quantization_value failed %zd", nread);
  119. return Mat();
  120. }
  121. size_t align_weight_data_size = alignSize(w * sizeof(unsigned char), 4);
  122. std::vector<unsigned char> index_array;
  123. index_array.resize(align_weight_data_size);
  124. nread = dr.read(index_array.data(), align_weight_data_size);
  125. if (nread != align_weight_data_size)
  126. {
  127. NCNN_LOGE("ModelBin read index_array failed %zd", nread);
  128. return Mat();
  129. }
  130. float* ptr = m;
  131. for (int i = 0; i < w; i++)
  132. {
  133. ptr[i] = quantization_value[index_array[i]];
  134. }
  135. }
  136. else if (flag_struct.f0 == 0)
  137. {
  138. // raw data
  139. nread = dr.read(m, w * sizeof(float));
  140. if (nread != w * sizeof(float))
  141. {
  142. NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
  143. return Mat();
  144. }
  145. }
  146. return m;
  147. }
  148. else if (type == 1)
  149. {
  150. Mat m(w);
  151. if (m.empty())
  152. return m;
  153. // raw data
  154. size_t nread = dr.read(m, w * sizeof(float));
  155. if (nread != w * sizeof(float))
  156. {
  157. NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
  158. return Mat();
  159. }
  160. return m;
  161. }
  162. else
  163. {
  164. NCNN_LOGE("ModelBin load type %d not implemented", type);
  165. return Mat();
  166. }
  167. return Mat();
  168. }
  169. ModelBinFromMatArray::ModelBinFromMatArray(const Mat* _weights)
  170. : weights(_weights)
  171. {
  172. }
  173. Mat ModelBinFromMatArray::load(int /*w*/, int /*type*/) const
  174. {
  175. if (!weights)
  176. return Mat();
  177. Mat m = weights[0];
  178. weights++;
  179. return m;
  180. }
  181. } // namespace ncnn