You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

modelbin.cpp 7.1 kB

8 years ago
8 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "modelbin.h"
  15. #include <stdio.h>
  16. #include <string.h>
  17. #include <vector>
  18. #include "platform.h"
  19. namespace ncnn {
  // File-local null byte pointer. The ModelBin constructors that are not given a
  // memory buffer initialize their `mem` member from this object.
  static const unsigned char* _null_mem = 0;
  21. #if NCNN_STDIO
  22. ModelBin::ModelBin(const Mat* _weights) : weights(_weights), binfp(0), mem(_null_mem)
  23. {
  24. }
  25. ModelBin::ModelBin(FILE* _binfp) : weights(0), binfp(_binfp), mem(_null_mem)
  26. {
  27. }
  28. ModelBin::ModelBin(const unsigned char*& _mem) : weights(0), binfp(0), mem(_mem)
  29. {
  30. }
  31. #else
  32. ModelBin::ModelBin(const Mat* _weights) : weights(_weights), mem(_null_mem)
  33. {
  34. }
  35. ModelBin::ModelBin(const unsigned char*& _mem) : weights(0), mem(_mem)
  36. {
  37. }
  38. #endif // NCNN_STDIO
  39. Mat ModelBin::load(int w, int type) const
  40. {
  41. if (weights)
  42. {
  43. Mat m = weights[0];
  44. weights++;
  45. return m;
  46. }
  47. #if NCNN_STDIO
  48. if (binfp)
  49. {
  50. if (type == 0)
  51. {
  52. int nread;
  53. union
  54. {
  55. struct
  56. {
  57. unsigned char f0;
  58. unsigned char f1;
  59. unsigned char f2;
  60. unsigned char f3;
  61. };
  62. unsigned int tag;
  63. } flag_struct;
  64. nread = fread(&flag_struct, sizeof(flag_struct), 1, binfp);
  65. if (nread != 1)
  66. {
  67. fprintf(stderr, "ModelBin read flag_struct failed %d\n", nread);
  68. return Mat();
  69. }
  70. unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
  71. if (flag_struct.tag == 0x01306B47)
  72. {
  73. // half-precision data
  74. int align_data_size = alignSize(w * sizeof(unsigned short), 4);
  75. std::vector<unsigned short> float16_weights;
  76. float16_weights.resize(align_data_size);
  77. nread = fread(float16_weights.data(), align_data_size, 1, binfp);
  78. if (nread != 1)
  79. {
  80. fprintf(stderr, "ModelBin read float16_weights failed %d\n", nread);
  81. return Mat();
  82. }
  83. return Mat::from_float16(float16_weights.data(), w);
  84. }
  85. Mat m(w);
  86. if (m.empty())
  87. return m;
  88. if (flag != 0)
  89. {
  90. // quantized data
  91. float quantization_value[256];
  92. nread = fread(quantization_value, 256 * sizeof(float), 1, binfp);
  93. if (nread != 1)
  94. {
  95. fprintf(stderr, "ModelBin read quantization_value failed %d\n", nread);
  96. return Mat();
  97. }
  98. int align_weight_data_size = alignSize(w * sizeof(unsigned char), 4);
  99. std::vector<unsigned char> index_array;
  100. index_array.resize(align_weight_data_size);
  101. nread = fread(index_array.data(), align_weight_data_size, 1, binfp);
  102. if (nread != 1)
  103. {
  104. fprintf(stderr, "ModelBin read index_array failed %d\n", nread);
  105. return Mat();
  106. }
  107. float* ptr = m;
  108. for (int i = 0; i < w; i++)
  109. {
  110. ptr[i] = quantization_value[ index_array[i] ];
  111. }
  112. }
  113. else if (flag_struct.f0 == 0)
  114. {
  115. // raw data
  116. nread = fread(m, w * sizeof(float), 1, binfp);
  117. if (nread != 1)
  118. {
  119. fprintf(stderr, "ModelBin read weight_data failed %d\n", nread);
  120. return Mat();
  121. }
  122. }
  123. return m;
  124. }
  125. else if (type == 1)
  126. {
  127. Mat m(w);
  128. if (m.empty())
  129. return m;
  130. // raw data
  131. int nread = fread(m, w * sizeof(float), 1, binfp);
  132. if (nread != 1)
  133. {
  134. fprintf(stderr, "ModelBin read weight_data failed %d\n", nread);
  135. return Mat();
  136. }
  137. return m;
  138. }
  139. else
  140. {
  141. fprintf(stderr, "ModelBin load type %d not implemented\n", type);
  142. return Mat();
  143. }
  144. }
  145. #endif // NCNN_STDIO
  146. if (type == 0)
  147. {
  148. union
  149. {
  150. struct
  151. {
  152. unsigned char f0;
  153. unsigned char f1;
  154. unsigned char f2;
  155. unsigned char f3;
  156. };
  157. unsigned int tag;
  158. } flag_struct;
  159. memcpy(&flag_struct, mem, sizeof(flag_struct));
  160. mem += sizeof(flag_struct);
  161. unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
  162. if (flag_struct.tag == 0x01306B47)
  163. {
  164. // half-precision data
  165. Mat m = Mat::from_float16((unsigned short*)mem, w);
  166. mem += alignSize(w * sizeof(unsigned short), 4);
  167. return m;
  168. }
  169. if (flag != 0)
  170. {
  171. // quantized data
  172. const float* quantization_value = (const float*)mem;
  173. mem += 256 * sizeof(float);
  174. const unsigned char* index_array = (const unsigned char*)mem;
  175. mem += alignSize(w * sizeof(unsigned char), 4);
  176. Mat m(w);
  177. if (m.empty())
  178. return m;
  179. float* ptr = m;
  180. for (int i = 0; i < w; i++)
  181. {
  182. ptr[i] = quantization_value[ index_array[i] ];
  183. }
  184. return m;
  185. }
  186. else if (flag_struct.f0 == 0)
  187. {
  188. // raw data
  189. Mat m = Mat(w, (float*)mem);
  190. mem += w * sizeof(float);
  191. return m;
  192. }
  193. }
  194. else if (type == 1)
  195. {
  196. // raw data
  197. Mat m = Mat(w, (float*)mem);
  198. mem += w * sizeof(float);
  199. return m;
  200. }
  201. else
  202. {
  203. fprintf(stderr, "ModelBin load type %d not implemented\n", type);
  204. return Mat();
  205. }
  206. return Mat();
  207. }
  208. Mat ModelBin::load(int w, int h, int type) const
  209. {
  210. Mat m = load(w * h, type);
  211. if (m.empty())
  212. return m;
  213. return m.reshape(w, h);
  214. }
  215. Mat ModelBin::load(int w, int h, int c, int type) const
  216. {
  217. Mat m = load(w * h * c, type);
  218. if (m.empty())
  219. return m;
  220. return m.reshape(w, h, c);
  221. }
  222. } // namespace ncnn