You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

yolov4.cpp 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "net.h"
  15. #include <opencv2/core/core.hpp>
  16. #include <opencv2/highgui/highgui.hpp>
  17. #include <opencv2/imgproc/imgproc.hpp>
  18. #include <stdio.h>
  19. #include <vector>
  20. #define YOLOV4_TINY 1 //0 or undef for yolov4
  21. struct Object
  22. {
  23. cv::Rect_<float> rect;
  24. int label;
  25. float prob;
  26. };
  27. static int detect_yolov4(const cv::Mat& bgr, std::vector<Object>& objects)
  28. {
  29. ncnn::Net yolov4;
  30. yolov4.opt.use_vulkan_compute = true;
  31. // original pretrained model from https://github.com/AlexeyAB/darknet
  32. // the ncnn model https://drive.google.com/drive/folders/1YzILvh0SKQPS_lrb33dmGNq7aVTKPWS0?usp=sharing
  33. // the ncnn model https://github.com/nihui/ncnn-assets/tree/master/models
  34. #if YOLOV4_TINY
  35. yolov4.load_param("yolov4-tiny-opt.param");
  36. yolov4.load_model("yolov4-tiny-opt.bin");
  37. const int target_size = 416;
  38. #else
  39. yolov4.load_param("yolov4-opt.param");
  40. yolov4.load_model("yolov4-opt.bin");
  41. const int target_size = 608;
  42. #endif
  43. int img_w = bgr.cols;
  44. int img_h = bgr.rows;
  45. ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR, bgr.cols, bgr.rows, target_size, target_size);
  46. const float mean_vals[3] = {0, 0, 0};
  47. const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f};
  48. in.substract_mean_normalize(mean_vals, norm_vals);
  49. ncnn::Extractor ex = yolov4.create_extractor();
  50. ex.set_num_threads(4);
  51. ex.input("data", in);
  52. ncnn::Mat out;
  53. ex.extract("output", out);
  54. // printf("%d %d %d\n", out.w, out.h, out.c);
  55. objects.clear();
  56. for (int i = 0; i < out.h; i++)
  57. {
  58. const float* values = out.row(i);
  59. Object object;
  60. object.label = values[0];
  61. object.prob = values[1];
  62. object.rect.x = values[2] * img_w;
  63. object.rect.y = values[3] * img_h;
  64. object.rect.width = values[4] * img_w - object.rect.x;
  65. object.rect.height = values[5] * img_h - object.rect.y;
  66. objects.push_back(object);
  67. }
  68. return 0;
  69. }
  70. static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
  71. {
  72. static const char* class_names[] = {"background", "person", "bicycle",
  73. "car", "motorbike", "aeroplane", "bus", "train", "truck",
  74. "boat", "traffic light", "fire hydrant", "stop sign",
  75. "parking meter", "bench", "bird", "cat", "dog", "horse",
  76. "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
  77. "backpack", "umbrella", "handbag", "tie", "suitcase",
  78. "frisbee", "skis", "snowboard", "sports ball", "kite",
  79. "baseball bat", "baseball glove", "skateboard", "surfboard",
  80. "tennis racket", "bottle", "wine glass", "cup", "fork",
  81. "knife", "spoon", "bowl", "banana", "apple", "sandwich",
  82. "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
  83. "cake", "chair", "sofa", "pottedplant", "bed", "diningtable",
  84. "toilet", "tvmonitor", "laptop", "mouse", "remote", "keyboard",
  85. "cell phone", "microwave", "oven", "toaster", "sink",
  86. "refrigerator", "book", "clock", "vase", "scissors",
  87. "teddy bear", "hair drier", "toothbrush"
  88. };
  89. cv::Mat image = bgr.clone();
  90. for (size_t i = 0; i < objects.size(); i++)
  91. {
  92. const Object& obj = objects[i];
  93. fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
  94. obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);
  95. cv::rectangle(image, obj.rect, cv::Scalar(255, 0, 0));
  96. char text[256];
  97. sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);
  98. int baseLine = 0;
  99. cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  100. int x = obj.rect.x;
  101. int y = obj.rect.y - label_size.height - baseLine;
  102. if (y < 0)
  103. y = 0;
  104. if (x + label_size.width > image.cols)
  105. x = image.cols - label_size.width;
  106. cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
  107. cv::Scalar(255, 255, 255), -1);
  108. cv::putText(image, text, cv::Point(x, y + label_size.height),
  109. cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
  110. }
  111. cv::imshow("image", image);
  112. cv::waitKey(0);
  113. }
  114. int main(int argc, char** argv)
  115. {
  116. if (argc != 2)
  117. {
  118. fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
  119. return -1;
  120. }
  121. const char* imagepath = argv[1];
  122. cv::Mat m = cv::imread(imagepath, 1);
  123. if (m.empty())
  124. {
  125. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  126. return -1;
  127. }
  128. std::vector<Object> objects;
  129. detect_yolov4(m, objects);
  130. draw_objects(m, objects);
  131. return 0;
  132. }