You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

modelbin.cpp 8.7 kB

8 years ago
8 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "modelbin.h"
  15. #include "datareader.h"
  16. #include <string.h>
  17. namespace ncnn {
  18. ModelBin::ModelBin()
  19. {
  20. }
  21. ModelBin::~ModelBin()
  22. {
  23. }
  24. Mat ModelBin::load(int w, int h, int type) const
  25. {
  26. Mat m = load(w * h, type);
  27. if (m.empty())
  28. return m;
  29. return m.reshape(w, h);
  30. }
  31. Mat ModelBin::load(int w, int h, int c, int type) const
  32. {
  33. Mat m = load(w * h * c, type);
  34. if (m.empty())
  35. return m;
  36. return m.reshape(w, h, c);
  37. }
  38. Mat ModelBin::load(int w, int h, int d, int c, int type) const
  39. {
  40. Mat m = load(w * h * d * c, type);
  41. if (m.empty())
  42. return m;
  43. return m.reshape(w, h, d, c);
  44. }
  45. class ModelBinFromDataReaderPrivate
  46. {
  47. public:
  48. ModelBinFromDataReaderPrivate(const DataReader& _dr)
  49. : dr(_dr)
  50. {
  51. }
  52. const DataReader& dr;
  53. };
  54. ModelBinFromDataReader::ModelBinFromDataReader(const DataReader& _dr)
  55. : ModelBin(), d(new ModelBinFromDataReaderPrivate(_dr))
  56. {
  57. }
  58. ModelBinFromDataReader::~ModelBinFromDataReader()
  59. {
  60. delete d;
  61. }
  62. ModelBinFromDataReader::ModelBinFromDataReader(const ModelBinFromDataReader&)
  63. : d(0)
  64. {
  65. }
  66. ModelBinFromDataReader& ModelBinFromDataReader::operator=(const ModelBinFromDataReader&)
  67. {
  68. return *this;
  69. }
  70. Mat ModelBinFromDataReader::load(int w, int type) const
  71. {
  72. Mat m;
  73. if (type == 0)
  74. {
  75. size_t nread;
  76. union
  77. {
  78. struct
  79. {
  80. unsigned char f0;
  81. unsigned char f1;
  82. unsigned char f2;
  83. unsigned char f3;
  84. };
  85. unsigned int tag;
  86. } flag_struct;
  87. nread = d->dr.read(&flag_struct, sizeof(flag_struct));
  88. if (nread != sizeof(flag_struct))
  89. {
  90. NCNN_LOGE("ModelBin read flag_struct failed %zd", nread);
  91. return Mat();
  92. }
  93. unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
  94. if (flag_struct.tag == 0x01306B47)
  95. {
  96. // half-precision data
  97. size_t align_data_size = alignSize(w * sizeof(unsigned short), 4);
  98. // try reference data
  99. const void* refbuf = 0;
  100. nread = d->dr.reference(align_data_size, &refbuf);
  101. if (nread == align_data_size)
  102. {
  103. m = Mat::from_float16((const unsigned short*)refbuf, w);
  104. }
  105. else
  106. {
  107. std::vector<unsigned short> float16_weights;
  108. float16_weights.resize(align_data_size);
  109. nread = d->dr.read(&float16_weights[0], align_data_size);
  110. if (nread != align_data_size)
  111. {
  112. NCNN_LOGE("ModelBin read float16_weights failed %zd", nread);
  113. return Mat();
  114. }
  115. m = Mat::from_float16(&float16_weights[0], w);
  116. }
  117. return m;
  118. }
  119. else if (flag_struct.tag == 0x000D4B38)
  120. {
  121. // int8 data
  122. size_t align_data_size = alignSize(w, 4);
  123. // try reference data
  124. const void* refbuf = 0;
  125. nread = d->dr.reference(align_data_size, &refbuf);
  126. if (nread == align_data_size)
  127. {
  128. m = Mat(w, (void*)refbuf, (size_t)1u);
  129. }
  130. else
  131. {
  132. std::vector<signed char> int8_weights;
  133. int8_weights.resize(align_data_size);
  134. nread = d->dr.read(&int8_weights[0], align_data_size);
  135. if (nread != align_data_size)
  136. {
  137. NCNN_LOGE("ModelBin read int8_weights failed %zd", nread);
  138. return Mat();
  139. }
  140. m.create(w, (size_t)1u);
  141. if (m.empty())
  142. return m;
  143. memcpy(m.data, &int8_weights[0], w);
  144. }
  145. return m;
  146. }
  147. else if (flag_struct.tag == 0x0002C056)
  148. {
  149. // try reference data
  150. const void* refbuf = 0;
  151. nread = d->dr.reference(w * sizeof(float), &refbuf);
  152. if (nread == w * sizeof(float))
  153. {
  154. m = Mat(w, (void*)refbuf);
  155. }
  156. else
  157. {
  158. m.create(w);
  159. if (m.empty())
  160. return m;
  161. // raw data with extra scaling
  162. nread = d->dr.read(m, w * sizeof(float));
  163. if (nread != w * sizeof(float))
  164. {
  165. NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
  166. return Mat();
  167. }
  168. }
  169. return m;
  170. }
  171. if (flag != 0)
  172. {
  173. m.create(w);
  174. if (m.empty())
  175. return m;
  176. // quantized data
  177. float quantization_value[256];
  178. nread = d->dr.read(quantization_value, 256 * sizeof(float));
  179. if (nread != 256 * sizeof(float))
  180. {
  181. NCNN_LOGE("ModelBin read quantization_value failed %zd", nread);
  182. return Mat();
  183. }
  184. size_t align_weight_data_size = alignSize(w * sizeof(unsigned char), 4);
  185. std::vector<unsigned char> index_array;
  186. index_array.resize(align_weight_data_size);
  187. nread = d->dr.read(&index_array[0], align_weight_data_size);
  188. if (nread != align_weight_data_size)
  189. {
  190. NCNN_LOGE("ModelBin read index_array failed %zd", nread);
  191. return Mat();
  192. }
  193. float* ptr = m;
  194. for (int i = 0; i < w; i++)
  195. {
  196. ptr[i] = quantization_value[index_array[i]];
  197. }
  198. }
  199. else if (flag_struct.f0 == 0)
  200. {
  201. // try reference data
  202. const void* refbuf = 0;
  203. nread = d->dr.reference(w * sizeof(float), &refbuf);
  204. if (nread == w * sizeof(float))
  205. {
  206. m = Mat(w, (void*)refbuf);
  207. }
  208. else
  209. {
  210. m.create(w);
  211. if (m.empty())
  212. return m;
  213. // raw data
  214. nread = d->dr.read(m, w * sizeof(float));
  215. if (nread != w * sizeof(float))
  216. {
  217. NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
  218. return Mat();
  219. }
  220. }
  221. }
  222. return m;
  223. }
  224. else if (type == 1)
  225. {
  226. // try reference data
  227. const void* refbuf = 0;
  228. size_t nread = d->dr.reference(w * sizeof(float), &refbuf);
  229. if (nread == w * sizeof(float))
  230. {
  231. m = Mat(w, (void*)refbuf);
  232. }
  233. else
  234. {
  235. m.create(w);
  236. if (m.empty())
  237. return m;
  238. // raw data
  239. size_t nread = d->dr.read(m, w * sizeof(float));
  240. if (nread != w * sizeof(float))
  241. {
  242. NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
  243. return Mat();
  244. }
  245. }
  246. return m;
  247. }
  248. else
  249. {
  250. NCNN_LOGE("ModelBin load type %d not implemented", type);
  251. return Mat();
  252. }
  253. return Mat();
  254. }
  255. class ModelBinFromMatArrayPrivate
  256. {
  257. public:
  258. ModelBinFromMatArrayPrivate(const Mat* _weights)
  259. : weights(_weights)
  260. {
  261. }
  262. mutable const Mat* weights;
  263. };
  264. ModelBinFromMatArray::ModelBinFromMatArray(const Mat* _weights)
  265. : ModelBin(), d(new ModelBinFromMatArrayPrivate(_weights))
  266. {
  267. }
  268. ModelBinFromMatArray::~ModelBinFromMatArray()
  269. {
  270. delete d;
  271. }
  272. ModelBinFromMatArray::ModelBinFromMatArray(const ModelBinFromMatArray&)
  273. : d(0)
  274. {
  275. }
  276. ModelBinFromMatArray& ModelBinFromMatArray::operator=(const ModelBinFromMatArray&)
  277. {
  278. return *this;
  279. }
  280. Mat ModelBinFromMatArray::load(int /*w*/, int /*type*/) const
  281. {
  282. if (!d->weights)
  283. return Mat();
  284. Mat m = d->weights[0];
  285. d->weights++;
  286. return m;
  287. }
  288. } // namespace ncnn