You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

rvm.cpp 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. #include "net.h"
  2. #if defined(USE_NCNN_SIMPLEOCV)
  3. #include "simpleocv.h"
  4. #else
  5. #include <opencv2/core/core.hpp>
  6. #include <opencv2/highgui/highgui.hpp>
  7. #include <opencv2/imgproc/imgproc.hpp>
  8. #endif
  9. #include <stdlib.h>
  10. #include <float.h>
  11. #include <stdio.h>
  12. #include <vector>
  13. static void draw_objects(const cv::Mat& bgr, const cv::Mat& fgr, const cv::Mat& pha)
  14. {
  15. cv::Mat fgr8U;
  16. fgr.convertTo(fgr8U, CV_8UC3, 255.0, 0);
  17. cv::Mat pha8U;
  18. pha.convertTo(pha8U, CV_8UC1, 255.0, 0);
  19. cv::Mat comp;
  20. cv::resize(bgr, comp, pha.size(), 0, 0, 1);
  21. for (int i = 0; i < pha8U.rows; i++)
  22. {
  23. for (int j = 0; j < pha8U.cols; j++)
  24. {
  25. uchar data = pha8U.at<uchar>(i, j);
  26. float alpha = (float)data / 255;
  27. comp.at<cv::Vec3b>(i, j)[0] = fgr8U.at<cv::Vec3b>(i, j)[0] * alpha + (1 - alpha) * 155;
  28. comp.at<cv::Vec3b>(i, j)[1] = fgr8U.at<cv::Vec3b>(i, j)[1] * alpha + (1 - alpha) * 255;
  29. comp.at<cv::Vec3b>(i, j)[2] = fgr8U.at<cv::Vec3b>(i, j)[2] * alpha + (1 - alpha) * 120;
  30. }
  31. }
  32. cv::imshow("pha", pha8U);
  33. cv::imshow("fgr", fgr8U);
  34. cv::imshow("comp", comp);
  35. cv::waitKey(0);
  36. }
  37. static int detect_rvm(const cv::Mat& bgr, cv::Mat& pha, cv::Mat& fgr)
  38. {
  39. const float downsample_ratio = 0.5f;
  40. const int target_width = 512;
  41. const int target_height = 512;
  42. ncnn::Net net;
  43. net.opt.use_vulkan_compute = false;
  44. //original pretrained model from https://github.com/PeterL1n/RobustVideoMatting
  45. //ncnn model https://pan.baidu.com/s/11iEY2RGfzWFtce8ue7T3JQ password: d9t6
  46. net.load_param("rvm_512.param");
  47. net.load_model("rvm_512.bin");
  48. //if you use another input size,pleaze change input shape
  49. ncnn::Mat r1i = ncnn::Mat(128, 128, 16);
  50. ncnn::Mat r2i = ncnn::Mat(64, 64, 20);
  51. ncnn::Mat r3i = ncnn::Mat(32, 32, 40);
  52. ncnn::Mat r4i = ncnn::Mat(16, 16, 64);
  53. r1i.fill(0.0f);
  54. r2i.fill(0.0f);
  55. r3i.fill(0.0f);
  56. r4i.fill(0.0f);
  57. ncnn::Extractor ex = net.create_extractor();
  58. const float mean_vals1[3] = {123.675f, 116.28f, 103.53f};
  59. const float norm_vals1[3] = {0.01712475f, 0.0175f, 0.01742919f};
  60. const float mean_vals2[3] = {0, 0, 0};
  61. const float norm_vals2[3] = {1 / 255.0, 1 / 255.0, 1 / 255.0};
  62. ncnn::Mat ncnn_in2 = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, bgr.cols, bgr.rows, target_width, target_height);
  63. ncnn::Mat ncnn_in1 = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, bgr.cols, bgr.rows, target_width * downsample_ratio, target_height * downsample_ratio);
  64. ncnn_in1.substract_mean_normalize(mean_vals1, norm_vals1);
  65. ncnn_in2.substract_mean_normalize(mean_vals2, norm_vals2);
  66. ex.input("src1", ncnn_in1);
  67. ex.input("src2", ncnn_in2);
  68. ex.input("r1i", r1i);
  69. ex.input("r2i", r2i);
  70. ex.input("r3i", r3i);
  71. ex.input("r4i", r4i);
  72. //if use video matting,these output will be input of next infer
  73. ex.extract("r4o", r4i);
  74. ex.extract("r3o", r3i);
  75. ex.extract("r2o", r2i);
  76. ex.extract("r1o", r1i);
  77. ncnn::Mat pha_;
  78. ex.extract("pha", pha_);
  79. ncnn::Mat fgr_;
  80. ex.extract("fgr", fgr_);
  81. cv::Mat cv_pha = cv::Mat(pha_.h, pha_.w, CV_32FC1, (float*)pha_.data);
  82. cv::Mat cv_fgr = cv::Mat(fgr_.h, fgr_.w, CV_32FC3);
  83. float* fgr_data = (float*)fgr_.data;
  84. for (int i = 0; i < fgr_.h; i++)
  85. {
  86. for (int j = 0; j < fgr_.w; j++)
  87. {
  88. cv_fgr.at<cv::Vec3f>(i, j)[2] = fgr_data[0 * fgr_.h * fgr_.w + i * fgr_.w + j];
  89. cv_fgr.at<cv::Vec3f>(i, j)[1] = fgr_data[1 * fgr_.h * fgr_.w + i * fgr_.w + j];
  90. cv_fgr.at<cv::Vec3f>(i, j)[0] = fgr_data[2 * fgr_.h * fgr_.w + i * fgr_.w + j];
  91. }
  92. }
  93. cv_pha.copyTo(pha);
  94. cv_fgr.copyTo(fgr);
  95. return 0;
  96. }
  97. int main(int argc, char** argv)
  98. {
  99. if (argc != 2)
  100. {
  101. fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
  102. return -1;
  103. }
  104. const char* imagepath = argv[1];
  105. cv::Mat m = cv::imread(imagepath, 1);
  106. if (m.empty())
  107. {
  108. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  109. return -1;
  110. }
  111. cv::Mat fgr, pha;
  112. detect_rvm(m, pha, fgr);
  113. draw_objects(m, fgr, pha);
  114. return 0;
  115. }