You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

new-model-load-api.md 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
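  1. ## Current model load API
  2. ### Cons
  3. #### Long, repetitive code
  4. #### Two near-duplicate functions to maintain (FILE* and in-memory)
  5. #### Every layer must handle float32, float16 and quantized-u8 weights itself
  6. #### Every layer must handle the 4-byte alignment padding itself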
  7. ```cpp
  8. #if NCNN_STDIO
  9. int Convolution::load_model(FILE* binfp)
  10. {
  11. int nread;
  12. union
  13. {
  14. struct
  15. {
  16. unsigned char f0;
  17. unsigned char f1;
  18. unsigned char f2;
  19. unsigned char f3;
  20. };
  21. unsigned int tag;
  22. } flag_struct;
  23. nread = fread(&flag_struct, sizeof(flag_struct), 1, binfp);
  24. if (nread != 1)
  25. {
  26. fprintf(stderr, "Convolution read flag_struct failed %d\n", nread);
  27. return -1;
  28. }
  29. unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
  30. weight_data.create(weight_data_size);
  31. if (weight_data.empty())
  32. return -100;
  33. if (flag_struct.tag == 0x01306B47)
  34. {
  35. // half-precision weight data
  36. int align_weight_data_size = alignSize(weight_data_size * sizeof(unsigned short), 4);
  37. std::vector<unsigned short> float16_weights;
  38. float16_weights.resize(align_weight_data_size);
  39. nread = fread(float16_weights.data(), align_weight_data_size, 1, binfp);
  40. if (nread != 1)
  41. {
  42. fprintf(stderr, "Convolution read float16_weights failed %d\n", nread);
  43. return -1;
  44. }
  45. weight_data = Mat::from_float16(float16_weights.data(), weight_data_size);
  46. if (weight_data.empty())
  47. return -100;
  48. }
  49. else if (flag != 0)
  50. {
  51. // quantized weight data
  52. float quantization_value[256];
  53. nread = fread(quantization_value, 256 * sizeof(float), 1, binfp);
  54. if (nread != 1)
  55. {
  56. fprintf(stderr, "Convolution read quantization_value failed %d\n", nread);
  57. return -1;
  58. }
  59. int align_weight_data_size = alignSize(weight_data_size * sizeof(unsigned char), 4);
  60. std::vector<unsigned char> index_array;
  61. index_array.resize(align_weight_data_size);
  62. nread = fread(index_array.data(), align_weight_data_size, 1, binfp);
  63. if (nread != 1)
  64. {
  65. fprintf(stderr, "Convolution read index_array failed %d\n", nread);
  66. return -1;
  67. }
  68. float* weight_data_ptr = weight_data;
  69. for (int i = 0; i < weight_data_size; i++)
  70. {
  71. weight_data_ptr[i] = quantization_value[ index_array[i] ];
  72. }
  73. }
  74. else if (flag_struct.f0 == 0)
  75. {
  76. // raw weight data
  77. nread = fread(weight_data, weight_data_size * sizeof(float), 1, binfp);
  78. if (nread != 1)
  79. {
  80. fprintf(stderr, "Convolution read weight_data failed %d\n", nread);
  81. return -1;
  82. }
  83. }
  84. if (bias_term)
  85. {
  86. bias_data.create(num_output);
  87. if (bias_data.empty())
  88. return -100;
  89. nread = fread(bias_data, num_output * sizeof(float), 1, binfp);
  90. if (nread != 1)
  91. {
  92. fprintf(stderr, "Convolution read bias_data failed %d\n", nread);
  93. return -1;
  94. }
  95. }
  96. return 0;
  97. }
  98. #endif // NCNN_STDIO
  99. int Convolution::load_model(const unsigned char*& mem)
  100. {
  101. union
  102. {
  103. struct
  104. {
  105. unsigned char f0;
  106. unsigned char f1;
  107. unsigned char f2;
  108. unsigned char f3;
  109. };
  110. unsigned int tag;
  111. } flag_struct;
  112. memcpy(&flag_struct, mem, sizeof(flag_struct));
  113. mem += sizeof(flag_struct);
  114. unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
  115. if (flag_struct.tag == 0x01306B47)
  116. {
  117. // half-precision weight data
  118. weight_data = Mat::from_float16((unsigned short*)mem, weight_data_size);
  119. mem += alignSize(weight_data_size * sizeof(unsigned short), 4);
  120. if (weight_data.empty())
  121. return -100;
  122. }
  123. else if (flag != 0)
  124. {
  125. // quantized weight data
  126. const float* quantization_value = (const float*)mem;
  127. mem += 256 * sizeof(float);
  128. const unsigned char* index_array = (const unsigned char*)mem;
  129. mem += alignSize(weight_data_size * sizeof(unsigned char), 4);
  130. weight_data.create(weight_data_size);
  131. if (weight_data.empty())
  132. return -100;
  133. float* weight_data_ptr = weight_data;
  134. for (int i = 0; i < weight_data_size; i++)
  135. {
  136. weight_data_ptr[i] = quantization_value[ index_array[i] ];
  137. }
  138. }
  139. else if (flag_struct.f0 == 0)
  140. {
  141. // raw weight data
  142. weight_data = Mat(weight_data_size, (float*)mem);
  143. mem += weight_data_size * sizeof(float);
  144. }
  145. if (bias_term)
  146. {
  147. bias_data = Mat(num_output, (float*)mem);
  148. mem += num_output * sizeof(float);
  149. }
  150. return 0;
  151. }
  152. ```
  153. ## Proposed new model load API
  154. ### Pros
  155. #### Clean and simple API
  156. #### Automatic element type detection
  157. ```cpp
  158. int Convolution::load_model(const ModelBin& mb)
  159. {
  160. // auto detect element type
  161. weight_data = mb.load(weight_data_size, 0);
  162. if (weight_data.empty())
  163. return -100;
  164. if (bias_term)
  165. {
  166. // certain type specified
  167. bias_data = mb.load(num_output, 1);
  168. if (bias_data.empty())
  169. return -100;
  170. }
  171. return 0;
  172. }
  173. ```