diff --git a/tools/ncnnoptimize.cpp b/tools/ncnnoptimize.cpp index 3c6eadbb3..8242607dc 100644 --- a/tools/ncnnoptimize.cpp +++ b/tools/ncnnoptimize.cpp @@ -98,7 +98,7 @@ public: int fprintf_param_int_array(int id, const ncnn::Mat& m, FILE* pp); int fprintf_param_float_array(int id, const ncnn::Mat& m, FILE* pp); - int fwrite_weight_tag(int tag, FILE* bp); + int fwrite_weight_tag_data(int tag, const ncnn::Mat& data, FILE* bp); int fwrite_weight_data(const ncnn::Mat& data, FILE* bp); int save(const char* parampath, const char* binpath); @@ -1158,16 +1158,21 @@ int NetOptimize::fprintf_param_float_array(int id, const ncnn::Mat& m, FILE* pp) return 0; } -int NetOptimize::fwrite_weight_tag(int tag, FILE* bp) +int NetOptimize::fwrite_weight_tag_data(int tag, const ncnn::Mat& data, FILE* bp) { + ncnn::Mat data_flattened = data.reshape(data.w * data.h * data.c); if (storage_type == 1 && tag == 0) { tag = 0x01306B47; // fp16 magic fwrite(&tag, sizeof(int), 1, bp); + ncnn::Mat data_flattened_fp16; + ncnn::cast_float32_to_float16(data_flattened, data_flattened_fp16); + fwrite(data_flattened_fp16.data, data_flattened_fp16.elemsize, data_flattened_fp16.w, bp); } else { fwrite(&tag, sizeof(int), 1, bp); + fwrite(data_flattened.data, data_flattened.elemsize, data_flattened.w, bp); } return 0; } @@ -1175,16 +1180,7 @@ int NetOptimize::fwrite_weight_tag(int tag, FILE* bp) int NetOptimize::fwrite_weight_data(const ncnn::Mat& data, FILE* bp) { ncnn::Mat data_flattened = data.reshape(data.w * data.h * data.c); - if (storage_type == 1) - { - ncnn::Mat data_flattened_fp16; - ncnn::cast_float32_to_float16(data_flattened, data_flattened_fp16); - fwrite(data_flattened_fp16.data, data_flattened_fp16.elemsize, data_flattened_fp16.w, bp); - } - else - { - fwrite(data_flattened.data, data_flattened.elemsize, data_flattened.w, bp); - } + fwrite(data_flattened.data, data_flattened.elemsize, data_flattened.w, bp); return 0; } @@ -1322,8 +1318,7 @@ int NetOptimize::save(const char* parampath, const char* binpath) fprintf_param_value(" 9=%d", activation_type) { if (!op->activation_params.empty()) fprintf_param_int_array(10, op->activation_params, pp); } - fwrite_weight_tag(0, bp); - fwrite_weight_data(op->weight_data, bp); + fwrite_weight_tag_data(0, op->weight_data, bp); fwrite_weight_data(op->bias_data, bp); } else if (layer->type == "ConvolutionDepthWise") @@ -1347,8 +1342,7 @@ int NetOptimize::save(const char* parampath, const char* binpath) fprintf_param_value(" 9=%d", activation_type) { if (!op->activation_params.empty()) fprintf_param_int_array(10, op->activation_params, pp); } - fwrite_weight_tag(0, bp); - fwrite_weight_data(op->weight_data, bp); + fwrite_weight_tag_data(0, op->weight_data, bp); fwrite_weight_data(op->bias_data, bp); } else if (layer->type == "Crop") @@ -1382,8 +1376,7 @@ int NetOptimize::save(const char* parampath, const char* binpath) fprintf_param_value(" 9=%d", activation_type) { if (!op->activation_params.empty()) fprintf_param_int_array(10, op->activation_params, pp); } - fwrite_weight_tag(0, bp); - fwrite_weight_data(op->weight_data, bp); + fwrite_weight_tag_data(0, op->weight_data, bp); fwrite_weight_data(op->bias_data, bp); } else if (layer->type == "DeconvolutionDepthWise") @@ -1406,8 +1399,7 @@ int NetOptimize::save(const char* parampath, const char* binpath) fprintf_param_value(" 9=%d", activation_type) { if (!op->activation_params.empty()) fprintf_param_int_array(10, op->activation_params, pp); } - fwrite_weight_tag(0, bp); - fwrite_weight_data(op->weight_data, bp); + fwrite_weight_tag_data(0, op->weight_data, bp); fwrite_weight_data(op->bias_data, bp); } else if (layer->type == "DetectionOutput") @@ -1468,8 +1460,7 @@ int NetOptimize::save(const char* parampath, const char* binpath) fprintf_param_value(" 9=%d", activation_type) { if (!op->activation_params.empty()) fprintf_param_int_array(10, op->activation_params, pp); } - fwrite_weight_tag(0, bp); - fwrite_weight_data(op->weight_data, bp); + fwrite_weight_tag_data(0, op->weight_data, bp); fwrite_weight_data(op->bias_data, bp); } else if (layer->type == "Input")