You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rvm.cpp 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. // ncnn model exported from https://github.com/PeterL1n/RobustVideoMatting
  15. //
  16. // import torch
  17. // from torch import nn
  18. // from model import MattingNetwork
  19. // from model.fast_guided_filter import FastGuidedFilterRefiner
  20. // from model.deep_guided_filter import DeepGuidedFilterRefiner
  21. //
  22. // class Model(nn.Module):
  23. //     def __init__(self):
  24. //         super().__init__()
  25. //
  26. //         self.rvm = MattingNetwork('mobilenetv3').eval()
  27. //         self.rvm.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
  28. //
  29. //         self.refiner_deep = DeepGuidedFilterRefiner()
  30. //         self.refiner_fast = FastGuidedFilterRefiner()
  31. //
  32. //     def forward_first_frame(self, src):
  33. //         return self.rvm(src)
  34. //
  35. //     def forward(self, src, src_sm, r1, r2, r3, r4):
  36. //
  37. //         f1, f2, f3, f4 = self.rvm.backbone(src_sm)
  38. //         f4 = self.rvm.aspp(f4)
  39. //         hid, *rec = self.rvm.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)
  40. //
  41. //         # downsample
  42. //         fgr_residual, pha = self.rvm.project_mat(hid).split([3, 1], dim=-3)
  43. //         fgr = fgr_residual + src_sm
  44. //
  45. //         # downsample + refiner_deep
  46. //         fgr_residual_deep, pha_deep = self.refiner_deep(src, src_sm, fgr_residual, pha, hid)
  47. //         fgr_deep = fgr_residual_deep + src
  48. //
  49. //         # downsample + refiner_fast
  50. //         fgr_residual_fast, pha_fast = self.refiner_fast(src, src_sm, fgr_residual, pha, hid)
  51. //         fgr_fast = fgr_residual_fast + src
  52. //
  53. //         # downsample + segmentation
  54. //         seg = self.rvm.project_seg(hid)
  55. //
  56. //         return fgr, pha, fgr_deep, pha_deep, fgr_fast, pha_fast, seg, *rec
  57. //
  58. // import pnnx
  59. //
  60. // model = Model().eval()
  61. //
  62. // x = torch.rand(1, 3, 512, 512)
  63. // x2 = torch.rand(1, 3, 256, 256)
  64. // x2_hr = torch.rand(1, 3, 1024, 1024)
  65. //
  66. // # generate feats via forward_first_frame, with different shapes
  67. // fgr, pha, r1, r2, r3, r4 = model.forward_first_frame(x)
  68. // fgr2, pha2, r12, r22, r32, r42 = model.forward_first_frame(x2)
  69. //
  70. // # export with dynamic shape
  71. // pnnx.export(model, "rvm_mobilenetv3.pt", (x, x, r1, r2, r3, r4), (x2_hr, x2, r12, r22, r32, r42))
  72. //
  73. // Then fix the refiner_fast fp16 overflow issue by appending the 31=1 layer feature mask to this line in ncnn.param:
  74. //
  75. // BinaryOp div_58 2 1 401 399 402 0=3 31=1
  76. //
  #include "net.h"

  #if defined(USE_NCNN_SIMPLEOCV)
  #include "simpleocv.h"
  #else
  #include <opencv2/core/core.hpp>
  #include <opencv2/highgui/highgui.hpp>
  #include <opencv2/imgproc/imgproc.hpp>
  #endif

  #include <stdio.h>
  #include <stdlib.h>

  #include <algorithm>
  85. static int detect_rvm(const cv::Mat& bgr, cv::Mat& fgr, cv::Mat& pha, cv::Mat& seg)
  86. {
  87. ncnn::Net rvm;
  88. rvm.opt.use_vulkan_compute = true;
  89. // https://github.com/nihui/ncnn-android-rvm/tree/master/app/src/main/assets
  90. // you shall also change r1,r2,r3,r4 shape below when model changed
  91. if (rvm.load_param("rvm_mobilenetv3.ncnn.param"))
  92. exit(-1);
  93. if (rvm.load_model("rvm_mobilenetv3.ncnn.bin"))
  94. exit(-1);
  95. // if (rvm.load_param("rvm_resnet50.ncnn.param"))
  96. // exit(-1);
  97. // if (rvm.load_model("rvm_resnet50.ncnn.bin"))
  98. // exit(-1);
  99. const int w = bgr.cols;
  100. const int h = bgr.rows;
  101. const int target_size = 512;
  102. const int max_stride = 16;
  103. bool refine_deep = true;
  104. // bool refine_fast = true;
  105. const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f};
  106. ncnn::Mat in_pad;
  107. ncnn::Mat in_small_pad;
  108. int wpad = 0;
  109. int hpad = 0;
  110. bool downsample = std::max(w, h) > target_size;
  111. if (downsample)
  112. {
  113. // letterbox pad to multiple of max_stride
  114. int w2 = w;
  115. int h2 = h;
  116. float scale = 1.f;
  117. if (w > h)
  118. {
  119. scale = (float)target_size / w;
  120. w2 = target_size;
  121. h2 = h2 * scale;
  122. }
  123. else
  124. {
  125. scale = (float)target_size / h;
  126. h2 = target_size;
  127. w2 = w2 * scale;
  128. }
  129. ncnn::Mat in_small = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, w, h, w2, h2);
  130. // letterbox pad to target_size rectangle
  131. int w2pad = (w2 + max_stride - 1) / max_stride * max_stride - w2;
  132. int h2pad = (h2 + max_stride - 1) / max_stride * max_stride - h2;
  133. ncnn::copy_make_border(in_small, in_small_pad, h2pad / 2, h2pad - h2pad / 2, w2pad / 2, w2pad - w2pad / 2, ncnn::BORDER_CONSTANT, 114.f);
  134. in_small_pad.substract_mean_normalize(0, norm_vals);
  135. int w3 = w;
  136. int h3 = h;
  137. if (w > h)
  138. {
  139. w3 = w;
  140. h3 = in_small_pad.h / scale;
  141. wpad = 0;
  142. hpad = h3 - h;
  143. }
  144. else
  145. {
  146. h3 = h;
  147. w3 = in_small_pad.w / scale;
  148. wpad = w3 - w;
  149. hpad = 0;
  150. }
  151. ncnn::Mat in = ncnn::Mat::from_pixels(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, w, h);
  152. ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f);
  153. in_pad.substract_mean_normalize(0, norm_vals);
  154. }
  155. else
  156. {
  157. ncnn::Mat in = ncnn::Mat::from_pixels(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, w, h);
  158. // letterbox pad to target_size rectangle
  159. wpad = (w + max_stride - 1) / max_stride * max_stride - w;
  160. hpad = (h + max_stride - 1) / max_stride * max_stride - h;
  161. ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f);
  162. in_pad.substract_mean_normalize(0, norm_vals);
  163. in_small_pad = in_pad;
  164. }
  165. // rvm_mobilenetv3
  166. ncnn::Mat r1(in_small_pad.w / 2, in_small_pad.h / 2, 16);
  167. ncnn::Mat r2(in_small_pad.w / 4, in_small_pad.h / 4, 20);
  168. ncnn::Mat r3(in_small_pad.w / 8, in_small_pad.h / 8, 40);
  169. ncnn::Mat r4(in_small_pad.w / 16, in_small_pad.h / 16, 64);
  170. // rvm_resnet50
  171. // ncnn::Mat r1(in_small_pad.w / 2, in_small_pad.h / 2, 16);
  172. // ncnn::Mat r2(in_small_pad.w / 4, in_small_pad.h / 4, 32);
  173. // ncnn::Mat r3(in_small_pad.w / 8, in_small_pad.h / 8, 64);
  174. // ncnn::Mat r4(in_small_pad.w / 16, in_small_pad.h / 16, 128);
  175. r1.fill(0.f);
  176. r2.fill(0.f);
  177. r3.fill(0.f);
  178. r4.fill(0.f);
  179. ncnn::Extractor ex = rvm.create_extractor();
  180. ex.input("in0", in_pad);
  181. ex.input("in1", in_small_pad);
  182. ex.input("in2", r1);
  183. ex.input("in3", r2);
  184. ex.input("in4", r3);
  185. ex.input("in5", r4);
  186. ncnn::Mat out_fgr;
  187. ncnn::Mat out_pha;
  188. if (downsample)
  189. {
  190. if (refine_deep)
  191. {
  192. // downsample + refine deep
  193. ex.extract("out2", out_fgr);
  194. ex.extract("out3", out_pha);
  195. }
  196. else // if (refine_fast)
  197. {
  198. // downsample + refine fast
  199. ex.extract("out4", out_fgr);
  200. ex.extract("out5", out_pha);
  201. }
  202. }
  203. else
  204. {
  205. // no downsample
  206. ex.extract("out0", out_fgr);
  207. ex.extract("out1", out_pha);
  208. }
  209. ncnn::Mat out_seg;
  210. // segmentation
  211. ex.extract("out6", out_seg);
  212. // feats
  213. ex.extract("out7", r1);
  214. ex.extract("out8", r2);
  215. ex.extract("out9", r3);
  216. ex.extract("out10", r4);
  217. const float denorm_vals[3] = {255.f, 255.f, 255.f};
  218. out_fgr.substract_mean_normalize(0, denorm_vals);
  219. fgr.create(out_fgr.h, out_fgr.w, CV_8UC3);
  220. out_fgr.to_pixels(fgr.data, ncnn::Mat::PIXEL_RGB2BGR);
  221. out_pha.substract_mean_normalize(0, denorm_vals);
  222. pha.create(out_pha.h, out_pha.w, CV_8UC1);
  223. out_pha.to_pixels(pha.data, ncnn::Mat::PIXEL_GRAY);
  224. out_seg.substract_mean_normalize(0, denorm_vals);
  225. seg.create(in_pad.h, in_pad.w, CV_8UC1);
  226. out_seg.to_pixels_resize(seg.data, ncnn::Mat::PIXEL_GRAY, in_pad.w, in_pad.h);
  227. // cut letterbox pad
  228. fgr = fgr(cv::Rect(wpad / 2, hpad / 2, w, h));
  229. pha = pha(cv::Rect(wpad / 2, hpad / 2, w, h));
  230. seg = seg(cv::Rect(wpad / 2, hpad / 2, w, h));
  231. return 0;
  232. }
  233. static void draw_objects(const cv::Mat& bgr, const cv::Mat& fgr, const cv::Mat& pha, const cv::Mat& seg)
  234. {
  235. const int w = bgr.cols;
  236. const int h = bgr.rows;
  237. // composite
  238. cv::Mat comp(h, w, CV_8UC3);
  239. for (int y = 0; y < h; y++)
  240. {
  241. const uchar* pf = fgr.ptr<const uchar>(y);
  242. const uchar* pa = pha.ptr<const uchar>(y);
  243. uchar* p = comp.ptr<uchar>(y);
  244. for (int x = 0; x < w; x++)
  245. {
  246. const float alpha = pa[0] / 255.f;
  247. p[0] = cv::saturate_cast<uchar>(pf[0] * alpha + (1 - alpha) * 155);
  248. p[1] = cv::saturate_cast<uchar>(pf[1] * alpha + (1 - alpha) * 255);
  249. p[2] = cv::saturate_cast<uchar>(pf[2] * alpha + (1 - alpha) * 120);
  250. pf += 3;
  251. pa += 1;
  252. p += 3;
  253. }
  254. }
  255. // composite seg
  256. cv::Mat comp_seg(h, w, CV_8UC3);
  257. for (int y = 0; y < h; y++)
  258. {
  259. const uchar* pb = bgr.ptr<const uchar>(y);
  260. const uchar* ps = seg.ptr<const uchar>(y);
  261. uchar* p = comp_seg.ptr<uchar>(y);
  262. for (int x = 0; x < w; x++)
  263. {
  264. const float alpha = ps[0] / 255.f;
  265. p[0] = cv::saturate_cast<uchar>(pb[0] * alpha + (1 - alpha) * 155);
  266. p[1] = cv::saturate_cast<uchar>(pb[1] * alpha + (1 - alpha) * 255);
  267. p[2] = cv::saturate_cast<uchar>(pb[2] * alpha + (1 - alpha) * 120);
  268. pb += 3;
  269. ps += 1;
  270. p += 3;
  271. }
  272. }
  273. cv::imshow("comp", comp);
  274. cv::imshow("comp_seg", comp_seg);
  275. cv::waitKey(0);
  276. }
  277. int main(int argc, char** argv)
  278. {
  279. if (argc != 2)
  280. {
  281. fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
  282. return -1;
  283. }
  284. const char* imagepath = argv[1];
  285. cv::Mat m = cv::imread(imagepath, 1);
  286. if (m.empty())
  287. {
  288. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  289. return -1;
  290. }
  291. cv::Mat fgr;
  292. cv::Mat pha;
  293. cv::Mat seg;
  294. detect_rvm(m, fgr, pha, seg);
  295. draw_objects(m, fgr, pha, seg);
  296. return 0;
  297. }