You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

rvm.cpp 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. // Copyright 2025 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // ncnn model exported from https://github.com/PeterL1n/RobustVideoMatting
  4. //
  5. // import torch
  6. // from torch import nn
  7. // from model import MattingNetwork
  8. // from model.fast_guided_filter import FastGuidedFilterRefiner
  9. // from model.deep_guided_filter import DeepGuidedFilterRefiner
  10. //
  // class Model(nn.Module):
  //     def __init__(self):
  //         super().__init__()
  //
  //         self.rvm = MattingNetwork('mobilenetv3').eval()
  //         self.rvm.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
  //
  //         self.refiner_deep = DeepGuidedFilterRefiner()
  //         self.refiner_fast = FastGuidedFilterRefiner()
  //
  //     def forward_first_frame(self, src):
  //         return self.rvm(src)
  //
  //     def forward(self, src, src_sm, r1, r2, r3, r4):
  //
  //         f1, f2, f3, f4 = self.rvm.backbone(src_sm)
  //         f4 = self.rvm.aspp(f4)
  //         hid, *rec = self.rvm.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)
  //
  //         # downsample
  //         fgr_residual, pha = self.rvm.project_mat(hid).split([3, 1], dim=-3)
  //         fgr = fgr_residual + src_sm
  //
  //         # downsample + refiner_deep
  //         fgr_residual_deep, pha_deep = self.refiner_deep(src, src_sm, fgr_residual, pha, hid)
  //         fgr_deep = fgr_residual_deep + src
  //
  //         # downsample + refiner_fast
  //         fgr_residual_fast, pha_fast = self.refiner_fast(src, src_sm, fgr_residual, pha, hid)
  //         fgr_fast = fgr_residual_fast + src
  //
  //         # downsample + segmentation
  //         seg = self.rvm.project_seg(hid)
  //
  //         return fgr, pha, fgr_deep, pha_deep, fgr_fast, pha_fast, seg, *rec
  46. //
  47. // import pnnx
  48. //
  49. // model = Model().eval()
  50. //
  51. // x = torch.rand(1, 3, 512, 512)
  52. // x2 = torch.rand(1, 3, 256, 256)
  53. // x2_hr = torch.rand(1, 3, 1024, 1024)
  54. //
  55. // # generate feats via forward_first_frame, with different shapes
  56. // fgr, pha, r1, r2, r3, r4 = model.forward_first_frame(x)
  57. // fgr2, pha2, r12, r22, r32, r42 = model.forward_first_frame(x2)
  58. //
  59. // # export with dynamic shape
  60. // pnnx.export(model, "rvm_mobilenetv3.pt", (x, x, r1, r2, r3, r4), (x2_hr, x2, r12, r22, r32, r42))
  61. //
  62. // and then fix refiner_fast fp16 overflow issue in ncnn.param via appending 31=1 layer feat mask
  63. //
  64. // BinaryOp div_58 2 1 401 399 402 0=3 31=1
  65. //
  #include "net.h"

  #include <algorithm>
  #if defined(USE_NCNN_SIMPLEOCV)
  #include "simpleocv.h"
  #else
  #include <opencv2/core/core.hpp>
  #include <opencv2/highgui/highgui.hpp>
  #include <opencv2/imgproc/imgproc.hpp>
  #endif
  #include <stdio.h>
  #include <stdlib.h>
  74. static int detect_rvm(const cv::Mat& bgr, cv::Mat& fgr, cv::Mat& pha, cv::Mat& seg)
  75. {
  76. ncnn::Net rvm;
  77. rvm.opt.use_vulkan_compute = true;
  78. // https://github.com/nihui/ncnn-android-rvm/tree/master/app/src/main/assets
  79. // you shall also change r1,r2,r3,r4 shape below when model changed
  80. if (rvm.load_param("rvm_mobilenetv3.ncnn.param"))
  81. exit(-1);
  82. if (rvm.load_model("rvm_mobilenetv3.ncnn.bin"))
  83. exit(-1);
  84. // if (rvm.load_param("rvm_resnet50.ncnn.param"))
  85. // exit(-1);
  86. // if (rvm.load_model("rvm_resnet50.ncnn.bin"))
  87. // exit(-1);
  88. const int w = bgr.cols;
  89. const int h = bgr.rows;
  90. const int target_size = 512;
  91. const int max_stride = 16;
  92. bool refine_deep = true;
  93. // bool refine_fast = true;
  94. const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f};
  95. ncnn::Mat in_pad;
  96. ncnn::Mat in_small_pad;
  97. int wpad = 0;
  98. int hpad = 0;
  99. bool downsample = std::max(w, h) > target_size;
  100. if (downsample)
  101. {
  102. // letterbox pad to multiple of max_stride
  103. int w2 = w;
  104. int h2 = h;
  105. float scale = 1.f;
  106. if (w > h)
  107. {
  108. scale = (float)target_size / w;
  109. w2 = target_size;
  110. h2 = h2 * scale;
  111. }
  112. else
  113. {
  114. scale = (float)target_size / h;
  115. h2 = target_size;
  116. w2 = w2 * scale;
  117. }
  118. ncnn::Mat in_small = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, w, h, w2, h2);
  119. // letterbox pad to target_size rectangle
  120. int w2pad = (w2 + max_stride - 1) / max_stride * max_stride - w2;
  121. int h2pad = (h2 + max_stride - 1) / max_stride * max_stride - h2;
  122. ncnn::copy_make_border(in_small, in_small_pad, h2pad / 2, h2pad - h2pad / 2, w2pad / 2, w2pad - w2pad / 2, ncnn::BORDER_CONSTANT, 114.f);
  123. in_small_pad.substract_mean_normalize(0, norm_vals);
  124. int w3 = w;
  125. int h3 = h;
  126. if (w > h)
  127. {
  128. w3 = w;
  129. h3 = in_small_pad.h / scale;
  130. wpad = 0;
  131. hpad = h3 - h;
  132. }
  133. else
  134. {
  135. h3 = h;
  136. w3 = in_small_pad.w / scale;
  137. wpad = w3 - w;
  138. hpad = 0;
  139. }
  140. ncnn::Mat in = ncnn::Mat::from_pixels(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, w, h);
  141. ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f);
  142. in_pad.substract_mean_normalize(0, norm_vals);
  143. }
  144. else
  145. {
  146. ncnn::Mat in = ncnn::Mat::from_pixels(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, w, h);
  147. // letterbox pad to target_size rectangle
  148. wpad = (w + max_stride - 1) / max_stride * max_stride - w;
  149. hpad = (h + max_stride - 1) / max_stride * max_stride - h;
  150. ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f);
  151. in_pad.substract_mean_normalize(0, norm_vals);
  152. in_small_pad = in_pad;
  153. }
  154. // rvm_mobilenetv3
  155. ncnn::Mat r1(in_small_pad.w / 2, in_small_pad.h / 2, 16);
  156. ncnn::Mat r2(in_small_pad.w / 4, in_small_pad.h / 4, 20);
  157. ncnn::Mat r3(in_small_pad.w / 8, in_small_pad.h / 8, 40);
  158. ncnn::Mat r4(in_small_pad.w / 16, in_small_pad.h / 16, 64);
  159. // rvm_resnet50
  160. // ncnn::Mat r1(in_small_pad.w / 2, in_small_pad.h / 2, 16);
  161. // ncnn::Mat r2(in_small_pad.w / 4, in_small_pad.h / 4, 32);
  162. // ncnn::Mat r3(in_small_pad.w / 8, in_small_pad.h / 8, 64);
  163. // ncnn::Mat r4(in_small_pad.w / 16, in_small_pad.h / 16, 128);
  164. r1.fill(0.f);
  165. r2.fill(0.f);
  166. r3.fill(0.f);
  167. r4.fill(0.f);
  168. ncnn::Extractor ex = rvm.create_extractor();
  169. ex.input("in0", in_pad);
  170. ex.input("in1", in_small_pad);
  171. ex.input("in2", r1);
  172. ex.input("in3", r2);
  173. ex.input("in4", r3);
  174. ex.input("in5", r4);
  175. ncnn::Mat out_fgr;
  176. ncnn::Mat out_pha;
  177. if (downsample)
  178. {
  179. if (refine_deep)
  180. {
  181. // downsample + refine deep
  182. ex.extract("out2", out_fgr);
  183. ex.extract("out3", out_pha);
  184. }
  185. else // if (refine_fast)
  186. {
  187. // downsample + refine fast
  188. ex.extract("out4", out_fgr);
  189. ex.extract("out5", out_pha);
  190. }
  191. }
  192. else
  193. {
  194. // no downsample
  195. ex.extract("out0", out_fgr);
  196. ex.extract("out1", out_pha);
  197. }
  198. ncnn::Mat out_seg;
  199. // segmentation
  200. ex.extract("out6", out_seg);
  201. // feats
  202. ex.extract("out7", r1);
  203. ex.extract("out8", r2);
  204. ex.extract("out9", r3);
  205. ex.extract("out10", r4);
  206. const float denorm_vals[3] = {255.f, 255.f, 255.f};
  207. out_fgr.substract_mean_normalize(0, denorm_vals);
  208. fgr.create(out_fgr.h, out_fgr.w, CV_8UC3);
  209. out_fgr.to_pixels(fgr.data, ncnn::Mat::PIXEL_RGB2BGR);
  210. out_pha.substract_mean_normalize(0, denorm_vals);
  211. pha.create(out_pha.h, out_pha.w, CV_8UC1);
  212. out_pha.to_pixels(pha.data, ncnn::Mat::PIXEL_GRAY);
  213. out_seg.substract_mean_normalize(0, denorm_vals);
  214. seg.create(in_pad.h, in_pad.w, CV_8UC1);
  215. out_seg.to_pixels_resize(seg.data, ncnn::Mat::PIXEL_GRAY, in_pad.w, in_pad.h);
  216. // cut letterbox pad
  217. fgr = fgr(cv::Rect(wpad / 2, hpad / 2, w, h));
  218. pha = pha(cv::Rect(wpad / 2, hpad / 2, w, h));
  219. seg = seg(cv::Rect(wpad / 2, hpad / 2, w, h));
  220. return 0;
  221. }
  222. static void draw_objects(const cv::Mat& bgr, const cv::Mat& fgr, const cv::Mat& pha, const cv::Mat& seg)
  223. {
  224. const int w = bgr.cols;
  225. const int h = bgr.rows;
  226. // composite
  227. cv::Mat comp(h, w, CV_8UC3);
  228. for (int y = 0; y < h; y++)
  229. {
  230. const uchar* pf = fgr.ptr<const uchar>(y);
  231. const uchar* pa = pha.ptr<const uchar>(y);
  232. uchar* p = comp.ptr<uchar>(y);
  233. for (int x = 0; x < w; x++)
  234. {
  235. const float alpha = pa[0] / 255.f;
  236. p[0] = cv::saturate_cast<uchar>(pf[0] * alpha + (1 - alpha) * 155);
  237. p[1] = cv::saturate_cast<uchar>(pf[1] * alpha + (1 - alpha) * 255);
  238. p[2] = cv::saturate_cast<uchar>(pf[2] * alpha + (1 - alpha) * 120);
  239. pf += 3;
  240. pa += 1;
  241. p += 3;
  242. }
  243. }
  244. // composite seg
  245. cv::Mat comp_seg(h, w, CV_8UC3);
  246. for (int y = 0; y < h; y++)
  247. {
  248. const uchar* pb = bgr.ptr<const uchar>(y);
  249. const uchar* ps = seg.ptr<const uchar>(y);
  250. uchar* p = comp_seg.ptr<uchar>(y);
  251. for (int x = 0; x < w; x++)
  252. {
  253. const float alpha = ps[0] / 255.f;
  254. p[0] = cv::saturate_cast<uchar>(pb[0] * alpha + (1 - alpha) * 155);
  255. p[1] = cv::saturate_cast<uchar>(pb[1] * alpha + (1 - alpha) * 255);
  256. p[2] = cv::saturate_cast<uchar>(pb[2] * alpha + (1 - alpha) * 120);
  257. pb += 3;
  258. ps += 1;
  259. p += 3;
  260. }
  261. }
  262. cv::imshow("comp", comp);
  263. cv::imshow("comp_seg", comp_seg);
  264. cv::waitKey(0);
  265. }
  266. int main(int argc, char** argv)
  267. {
  268. if (argc != 2)
  269. {
  270. fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
  271. return -1;
  272. }
  273. const char* imagepath = argv[1];
  274. cv::Mat m = cv::imread(imagepath, 1);
  275. if (m.empty())
  276. {
  277. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  278. return -1;
  279. }
  280. cv::Mat fgr;
  281. cv::Mat pha;
  282. cv::Mat seg;
  283. detect_rvm(m, fgr, pha, seg);
  284. draw_objects(m, fgr, pha, seg);
  285. return 0;
  286. }