You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

modelbin.cpp 6.7 kB

8 years ago
8 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "modelbin.h"
  15. #include "datareader.h"
  16. #include <string.h>
  17. namespace ncnn {
  18. ModelBin::ModelBin()
  19. {
  20. }
  21. ModelBin::~ModelBin()
  22. {
  23. }
  24. Mat ModelBin::load(int w, int h, int type) const
  25. {
  26. Mat m = load(w * h, type);
  27. if (m.empty())
  28. return m;
  29. return m.reshape(w, h);
  30. }
  31. Mat ModelBin::load(int w, int h, int c, int type) const
  32. {
  33. Mat m = load(w * h * c, type);
  34. if (m.empty())
  35. return m;
  36. return m.reshape(w, h, c);
  37. }
  38. class ModelBinFromDataReaderPrivate
  39. {
  40. public:
  41. ModelBinFromDataReaderPrivate(const DataReader& _dr)
  42. : dr(_dr)
  43. {
  44. }
  45. const DataReader& dr;
  46. };
  47. ModelBinFromDataReader::ModelBinFromDataReader(const DataReader& _dr)
  48. : ModelBin(), d(new ModelBinFromDataReaderPrivate(_dr))
  49. {
  50. }
  51. ModelBinFromDataReader::~ModelBinFromDataReader()
  52. {
  53. delete d;
  54. }
  55. ModelBinFromDataReader::ModelBinFromDataReader(const ModelBinFromDataReader&)
  56. : d(0)
  57. {
  58. }
  59. ModelBinFromDataReader& ModelBinFromDataReader::operator=(const ModelBinFromDataReader&)
  60. {
  61. return *this;
  62. }
  63. Mat ModelBinFromDataReader::load(int w, int type) const
  64. {
  65. if (type == 0)
  66. {
  67. size_t nread;
  68. union
  69. {
  70. struct
  71. {
  72. unsigned char f0;
  73. unsigned char f1;
  74. unsigned char f2;
  75. unsigned char f3;
  76. };
  77. unsigned int tag;
  78. } flag_struct;
  79. nread = d->dr.read(&flag_struct, sizeof(flag_struct));
  80. if (nread != sizeof(flag_struct))
  81. {
  82. NCNN_LOGE("ModelBin read flag_struct failed %zd", nread);
  83. return Mat();
  84. }
  85. unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
  86. if (flag_struct.tag == 0x01306B47)
  87. {
  88. // half-precision data
  89. size_t align_data_size = alignSize(w * sizeof(unsigned short), 4);
  90. std::vector<unsigned short> float16_weights;
  91. float16_weights.resize(align_data_size);
  92. nread = d->dr.read(float16_weights.data(), align_data_size);
  93. if (nread != align_data_size)
  94. {
  95. NCNN_LOGE("ModelBin read float16_weights failed %zd", nread);
  96. return Mat();
  97. }
  98. return Mat::from_float16(float16_weights.data(), w);
  99. }
  100. else if (flag_struct.tag == 0x000D4B38)
  101. {
  102. // int8 data
  103. size_t align_data_size = alignSize(w, 4);
  104. std::vector<signed char> int8_weights;
  105. int8_weights.resize(align_data_size);
  106. nread = d->dr.read(int8_weights.data(), align_data_size);
  107. if (nread != align_data_size)
  108. {
  109. NCNN_LOGE("ModelBin read int8_weights failed %zd", nread);
  110. return Mat();
  111. }
  112. Mat m(w, (size_t)1u);
  113. if (m.empty())
  114. return m;
  115. memcpy(m.data, int8_weights.data(), w);
  116. return m;
  117. }
  118. else if (flag_struct.tag == 0x0002C056)
  119. {
  120. Mat m(w);
  121. if (m.empty())
  122. return m;
  123. // raw data with extra scaling
  124. nread = d->dr.read(m, w * sizeof(float));
  125. if (nread != w * sizeof(float))
  126. {
  127. NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
  128. return Mat();
  129. }
  130. return m;
  131. }
  132. Mat m(w);
  133. if (m.empty())
  134. return m;
  135. if (flag != 0)
  136. {
  137. // quantized data
  138. float quantization_value[256];
  139. nread = d->dr.read(quantization_value, 256 * sizeof(float));
  140. if (nread != 256 * sizeof(float))
  141. {
  142. NCNN_LOGE("ModelBin read quantization_value failed %zd", nread);
  143. return Mat();
  144. }
  145. size_t align_weight_data_size = alignSize(w * sizeof(unsigned char), 4);
  146. std::vector<unsigned char> index_array;
  147. index_array.resize(align_weight_data_size);
  148. nread = d->dr.read(index_array.data(), align_weight_data_size);
  149. if (nread != align_weight_data_size)
  150. {
  151. NCNN_LOGE("ModelBin read index_array failed %zd", nread);
  152. return Mat();
  153. }
  154. float* ptr = m;
  155. for (int i = 0; i < w; i++)
  156. {
  157. ptr[i] = quantization_value[index_array[i]];
  158. }
  159. }
  160. else if (flag_struct.f0 == 0)
  161. {
  162. // raw data
  163. nread = d->dr.read(m, w * sizeof(float));
  164. if (nread != w * sizeof(float))
  165. {
  166. NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
  167. return Mat();
  168. }
  169. }
  170. return m;
  171. }
  172. else if (type == 1)
  173. {
  174. Mat m(w);
  175. if (m.empty())
  176. return m;
  177. // raw data
  178. size_t nread = d->dr.read(m, w * sizeof(float));
  179. if (nread != w * sizeof(float))
  180. {
  181. NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
  182. return Mat();
  183. }
  184. return m;
  185. }
  186. else
  187. {
  188. NCNN_LOGE("ModelBin load type %d not implemented", type);
  189. return Mat();
  190. }
  191. return Mat();
  192. }
  193. class ModelBinFromMatArrayPrivate
  194. {
  195. public:
  196. ModelBinFromMatArrayPrivate(const Mat* _weights)
  197. : weights(_weights)
  198. {
  199. }
  200. mutable const Mat* weights;
  201. };
  202. ModelBinFromMatArray::ModelBinFromMatArray(const Mat* _weights)
  203. : ModelBin(), d(new ModelBinFromMatArrayPrivate(_weights))
  204. {
  205. }
  206. ModelBinFromMatArray::~ModelBinFromMatArray()
  207. {
  208. delete d;
  209. }
  210. ModelBinFromMatArray::ModelBinFromMatArray(const ModelBinFromMatArray&)
  211. : d(0)
  212. {
  213. }
  214. ModelBinFromMatArray& ModelBinFromMatArray::operator=(const ModelBinFromMatArray&)
  215. {
  216. return *this;
  217. }
  218. Mat ModelBinFromMatArray::load(int /*w*/, int /*type*/) const
  219. {
  220. if (!d->weights)
  221. return Mat();
  222. Mat m = d->weights[0];
  223. d->weights++;
  224. return m;
  225. }
  226. } // namespace ncnn