You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

modelbin.cpp 6.6 kB

8 years ago
8 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "modelbin.h"
  15. #include <stdio.h>
  16. #include <string.h>
  17. #include <vector>
  18. #include "platform.h"
  19. namespace ncnn {
  20. Mat ModelBin::load(int w, int h, int type) const
  21. {
  22. Mat m = load(w * h, type);
  23. if (m.empty())
  24. return m;
  25. return m.reshape(w, h);
  26. }
  27. Mat ModelBin::load(int w, int h, int c, int type) const
  28. {
  29. Mat m = load(w * h * c, type);
  30. if (m.empty())
  31. return m;
  32. return m.reshape(w, h, c);
  33. }
  34. #if NCNN_STDIO
  35. ModelBinFromStdio::ModelBinFromStdio(FILE* _binfp) : binfp(_binfp)
  36. {
  37. }
  38. Mat ModelBinFromStdio::load(int w, int type) const
  39. {
  40. if (!binfp)
  41. return Mat();
  42. if (type == 0)
  43. {
  44. int nread;
  45. union
  46. {
  47. struct
  48. {
  49. unsigned char f0;
  50. unsigned char f1;
  51. unsigned char f2;
  52. unsigned char f3;
  53. };
  54. unsigned int tag;
  55. } flag_struct;
  56. nread = fread(&flag_struct, sizeof(flag_struct), 1, binfp);
  57. if (nread != 1)
  58. {
  59. fprintf(stderr, "ModelBin read flag_struct failed %d\n", nread);
  60. return Mat();
  61. }
  62. unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
  63. if (flag_struct.tag == 0x01306B47)
  64. {
  65. // half-precision data
  66. int align_data_size = alignSize(w * sizeof(unsigned short), 4);
  67. std::vector<unsigned short> float16_weights;
  68. float16_weights.resize(align_data_size);
  69. nread = fread(float16_weights.data(), align_data_size, 1, binfp);
  70. if (nread != 1)
  71. {
  72. fprintf(stderr, "ModelBin read float16_weights failed %d\n", nread);
  73. return Mat();
  74. }
  75. return Mat::from_float16(float16_weights.data(), w);
  76. }
  77. Mat m(w);
  78. if (m.empty())
  79. return m;
  80. if (flag != 0)
  81. {
  82. // quantized data
  83. float quantization_value[256];
  84. nread = fread(quantization_value, 256 * sizeof(float), 1, binfp);
  85. if (nread != 1)
  86. {
  87. fprintf(stderr, "ModelBin read quantization_value failed %d\n", nread);
  88. return Mat();
  89. }
  90. int align_weight_data_size = alignSize(w * sizeof(unsigned char), 4);
  91. std::vector<unsigned char> index_array;
  92. index_array.resize(align_weight_data_size);
  93. nread = fread(index_array.data(), align_weight_data_size, 1, binfp);
  94. if (nread != 1)
  95. {
  96. fprintf(stderr, "ModelBin read index_array failed %d\n", nread);
  97. return Mat();
  98. }
  99. float* ptr = m;
  100. for (int i = 0; i < w; i++)
  101. {
  102. ptr[i] = quantization_value[ index_array[i] ];
  103. }
  104. }
  105. else if (flag_struct.f0 == 0)
  106. {
  107. // raw data
  108. nread = fread(m, w * sizeof(float), 1, binfp);
  109. if (nread != 1)
  110. {
  111. fprintf(stderr, "ModelBin read weight_data failed %d\n", nread);
  112. return Mat();
  113. }
  114. }
  115. return m;
  116. }
  117. else if (type == 1)
  118. {
  119. Mat m(w);
  120. if (m.empty())
  121. return m;
  122. // raw data
  123. int nread = fread(m, w * sizeof(float), 1, binfp);
  124. if (nread != 1)
  125. {
  126. fprintf(stderr, "ModelBin read weight_data failed %d\n", nread);
  127. return Mat();
  128. }
  129. return m;
  130. }
  131. else
  132. {
  133. fprintf(stderr, "ModelBin load type %d not implemented\n", type);
  134. return Mat();
  135. }
  136. return Mat();
  137. }
  138. #endif // NCNN_STDIO
  139. ModelBinFromMemory::ModelBinFromMemory(const unsigned char*& _mem) : mem(_mem)
  140. {
  141. }
  142. Mat ModelBinFromMemory::load(int w, int type) const
  143. {
  144. if (!mem)
  145. return Mat();
  146. if (type == 0)
  147. {
  148. union
  149. {
  150. struct
  151. {
  152. unsigned char f0;
  153. unsigned char f1;
  154. unsigned char f2;
  155. unsigned char f3;
  156. };
  157. unsigned int tag;
  158. } flag_struct;
  159. memcpy(&flag_struct, mem, sizeof(flag_struct));
  160. mem += sizeof(flag_struct);
  161. unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
  162. if (flag_struct.tag == 0x01306B47)
  163. {
  164. // half-precision data
  165. Mat m = Mat::from_float16((unsigned short*)mem, w);
  166. mem += alignSize(w * sizeof(unsigned short), 4);
  167. return m;
  168. }
  169. if (flag != 0)
  170. {
  171. // quantized data
  172. const float* quantization_value = (const float*)mem;
  173. mem += 256 * sizeof(float);
  174. const unsigned char* index_array = (const unsigned char*)mem;
  175. mem += alignSize(w * sizeof(unsigned char), 4);
  176. Mat m(w);
  177. if (m.empty())
  178. return m;
  179. float* ptr = m;
  180. for (int i = 0; i < w; i++)
  181. {
  182. ptr[i] = quantization_value[ index_array[i] ];
  183. }
  184. return m;
  185. }
  186. else if (flag_struct.f0 == 0)
  187. {
  188. // raw data
  189. Mat m = Mat(w, (float*)mem);
  190. mem += w * sizeof(float);
  191. return m;
  192. }
  193. }
  194. else if (type == 1)
  195. {
  196. // raw data
  197. Mat m = Mat(w, (float*)mem);
  198. mem += w * sizeof(float);
  199. return m;
  200. }
  201. else
  202. {
  203. fprintf(stderr, "ModelBin load type %d not implemented\n", type);
  204. return Mat();
  205. }
  206. return Mat();
  207. }
  208. ModelBinFromMatArray::ModelBinFromMatArray(const Mat* _weights) : weights(_weights)
  209. {
  210. }
  211. Mat ModelBinFromMatArray::load(int /*w*/, int /*type*/) const
  212. {
  213. if (!weights)
  214. return Mat();
  215. Mat m = weights[0];
  216. weights++;
  217. return m;
  218. }
  219. } // namespace ncnn