# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import os
import os.path as osp
import shutil
import sys
import time
import warnings
from functools import partial

sys.path.append("/tmp/code/code_test")

import mmcv
import numpy as np
import onnx
import onnxruntime as ort
import torch
from pycocotools.coco import COCO
from sklearn.covariance import LedoitWolf

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet.apis import (inference_detector, init_detector, multi_gpu_test,
                        single_gpu_test)
from mmdet.core.export import build_model_from_cfg, preprocess_example_input
from mmdet.core.export.model_wrappers import ONNXRuntimeDetector
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector

print(f'onnxruntime device: {ort.get_device()}')  # expected: GPU
print(f'ort avail providers: {ort.get_available_providers()}')
# expected: ['CUDAExecutionProvider', 'CPUExecutionProvider']
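
# ---------------------------------------------------------------------------
# Overview (added summary, inferred from the code below):
#   1. Evaluate every *.pth checkpoint under --train-work-dir on the
#      COCO-style test set under --data-path and keep the best one by
#      AUC + bbox_mAP_50.
#   2. Export the best checkpoint to ONNX (pytorch2onnx) and verify it
#      against the PyTorch outputs with ONNX Runtime.
#   3. Compute a per-threshold recall curve (cal_recall) and per-image
#      features (cal_features), then save mean / Ledoit-Wolf covariance /
#      pseudo-inverse for the inference service under <work-dir>/infer.
# ---------------------------------------------------------------------------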

def cal_features(config_file, checkpoint_file, data_path_images, data_path_labels):
    # Run the detector over every image listed in the COCO annotation file
    # and collect the per-image feature vectors.
    data_coco = json.load(open(data_path_labels))
    data_name = data_coco["images"]
    model = init_detector(config_file, checkpoint_file, device='cuda:0')
    imgs_name = []
    for i in range(len(data_name)):
        imgs_name.append(osp.join(data_path_images, data_name[i]["file_name"]))
    print("before infer")
    index = 0
    num = len(imgs_name)
    results = []
    step = 1  # inference batch size
    while index < num:
        index += step
        # `feat=True` is a project-local extension of mmdet's
        # inference_detector that also returns feature vectors.
        if index < num:
            _, results_tmp = inference_detector(model, imgs_name[index - step:index], feat=True)
        else:
            _, results_tmp = inference_detector(model, imgs_name[index - step:num], feat=True)
        results += results_tmp
    print("after infer")
    return results
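
# Example of using cal_features standalone (hypothetical paths, mirroring
# how main() calls it below):
#
#   feats = cal_features('/model/config.py', '/model/epoch_12.pth',
#                        '/dataset/images',
#                        '/dataset/annotations/instances_annotations.json')
#   feats = np.array(feats)  # shape: (num_images, feat_dim)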

def cal_recall(config_file, checkpoint_file, data_path_images, data_path_labels):
    # Image-level OK/NG classification over a sweep of score thresholds:
    # an image is NG (label 1) if it has at least one annotation, and is
    # predicted NG if any detection exceeds the score threshold.
    data_coco = json.load(open(data_path_labels))
    data_name = data_coco["images"]
    data_ann = data_coco['annotations']
    # map image_id -> file_name once instead of scanning `images` per annotation
    id2name = {img["id"]: img["file_name"] for img in data_name}
    boxes = {}
    for res in data_ann:
        img_name = id2name[res["image_id"]]
        bbox = res["bbox"]
        label = res["category_id"]
        bbox.append(int(label))
        if img_name in boxes:
            boxes[img_name].append(bbox)
        else:
            boxes[img_name] = [bbox]
    model = init_detector(config_file, checkpoint_file, device='cuda:0')
    imgs_labels = []
    imgs_name = []
    num_ng = 0
    for i in range(len(data_name)):
        res_label = 0
        if data_name[i]["file_name"] in boxes:
            res_label = 1
            num_ng += 1
        imgs_labels.append(res_label)
        imgs_name.append(osp.join(data_path_images, data_name[i]["file_name"]))
    num_ok = len(data_name) - num_ng
    print(len(imgs_labels), num_ok, num_ng)
    print("before infer")
    index = 0
    num = len(imgs_name)
    results = []
    step = 1  # inference batch size
    while index < num:
        index += step
        if index < num:
            results_tmp = inference_detector(model, imgs_name[index - step:index])
        else:
            results_tmp = inference_detector(model, imgs_name[index - step:num])
        results += results_tmp
    print("after infer")
    recall_thrs = []
    for score_thr in np.arange(0.01, 0.5, 0.01):
        imgs_results = []
        for result in results:
            res_predict = 0
            for i in result:  # per-class (N, 5) arrays: x1, y1, x2, y2, score
                for j in range(i.shape[0]):
                    if i[j, 4] > score_thr:
                        res_predict = 1
            imgs_results.append(res_predict)
        count_ng = 0
        count_ok = 0
        for i in range(len(imgs_labels)):
            if imgs_labels[i] == 0 and imgs_results[i] == 0:
                count_ok += 1
            if imgs_labels[i] == 1 and imgs_results[i] == 1:
                count_ng += 1
        recall_thr = {
            "score_thr": score_thr,
            "recall(ok)": count_ok / (num_ok + 1e-8),
            "recall(ng)": count_ng / (num_ng + 1e-8),
        }
        recall_thrs.append(recall_thr)
    return recall_thrs
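
# Note: despite the names, "recall(ok)" is the fraction of defect-free images
# correctly passed (specificity) and "recall(ng)" is the fraction of defective
# images correctly flagged (sensitivity); the 1e-8 terms only guard against
# division by zero when one class is empty.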

def pytorch2onnx(model,
                 input_img,
                 input_shape,
                 normalize_cfg,
                 opset_version=11,
                 show=False,
                 output_file='model.onnx',
                 verify=True,
                 test_img=None,
                 do_simplify=False,
                 dynamic_export=True,
                 skip_postprocess=False):
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }
    # prepare input
    one_img, one_meta = preprocess_example_input(input_config)
    img_list, img_meta_list = [one_img], [[one_meta]]

    if skip_postprocess:
        warnings.warn('Not all models support export onnx without post '
                      'process, especially two stage detectors!')
        model.forward = model.forward_dummy
        torch.onnx.export(
            model,
            one_img,
            output_file,
            input_names=['input'],
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=show,
            opset_version=opset_version)
        print(f'Successfully exported ONNX model without '
              f'post process: {output_file}')
        return

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=False)

    # 'feature', 'entropy' and 'learning_loss' are project-specific extra
    # outputs on top of mmdet's usual 'dets' / 'labels' (/ 'masks')
    output_names = ['dets', 'labels', 'feature', 'entropy', 'learning_loss']
    if model.with_mask:
        output_names.append('masks')
    input_name = 'input'
    dynamic_axes = None
    if dynamic_export:
        dynamic_axes = {
            input_name: {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
            'feature': {
                0: 'batch',
                1: 'feat_dim',
            },
            'entropy': {
                0: 'batch',
                1: '1',
            },
            'learning_loss': {
                0: 'batch',
                1: '1',
            },
        }
        if model.with_mask:
            dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}

    torch.onnx.export(
        model,
        img_list,
        output_file,
        input_names=[input_name],
        output_names=output_names,
        export_params=True,
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=show,
        opset_version=opset_version,
        dynamic_axes=dynamic_axes)

    model.forward = origin_forward

    # get the custom op path
    ort_custom_op_path = ''
    try:
        from mmcv.ops import get_onnxruntime_op_path
        ort_custom_op_path = get_onnxruntime_op_path()
    except (ImportError, ModuleNotFoundError):
        warnings.warn('If input model has custom op from mmcv, '
                      'you may have to build mmcv with ONNXRuntime from source.')

    if do_simplify:
        import onnxsim

        from mmdet import digit_version

        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnxsim>={min_required_version}'

        input_dic = {'input': img_list[0].detach().cpu().numpy()}
        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_data=input_dic,
            custom_lib=ort_custom_op_path,
            dynamic_input_shape=dynamic_export)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            warnings.warn('Failed to simplify ONNX model.')
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # wrap onnx model
        onnx_model = ONNXRuntimeDetector(output_file, model.CLASSES, 0)
        if dynamic_export:
            # scale up to test dynamic shape
            h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
            h, w = min(1344, h), min(1344, w)
            input_config['input_shape'] = (1, 3, h, w)
        if test_img is None:
            input_config['input_path'] = input_img
        # prepare input once again
        one_img, one_meta = preprocess_example_input(input_config)
        img_list, img_meta_list = [one_img], [[one_meta]]

        # get pytorch output
        with torch.no_grad():
            pytorch_results = model(
                img_list,
                img_metas=img_meta_list,
                return_loss=False,
                rescale=True)[0]

        img_list = [_.cuda().contiguous() for _ in img_list]
        if dynamic_export:
            img_list = img_list + [_.flip(-1).contiguous() for _ in img_list]
            img_meta_list = img_meta_list * 2
        # get onnx output
        onnx_results = onnx_model(
            img_list, img_metas=img_meta_list, return_loss=False)[0]

        # visualize predictions
        score_thr = 0.3
        if show:
            out_file_ort, out_file_pt = None, None
        else:
            out_file_ort, out_file_pt = 'show-ort.png', 'show-pt.png'
        show_img = one_meta['show_img']
        model.show_result(
            show_img,
            pytorch_results,
            score_thr=score_thr,
            show=True,
            win_name='PyTorch',
            out_file=out_file_pt)
        onnx_model.show_result(
            show_img,
            onnx_results,
            score_thr=score_thr,
            show=True,
            win_name='ONNXRuntime',
            out_file=out_file_ort)

        for i in range(len(onnx_results)):
            print(onnx_results[i].shape)
        print("***************")
        for i in range(len(pytorch_results)):
            print(pytorch_results[i].shape)

        # compare a part of result
        if model.with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between Pytorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        # check the numerical value
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(
                    o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
        print('The numerical values are the same between Pytorch and ONNX')
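
# A minimal sketch (not used by this script) of running the exported model
# directly with ONNX Runtime; 'input' and the output names match the
# torch.onnx.export call above:
#
#   sess = ort.InferenceSession('model.onnx',
#                               providers=['CUDAExecutionProvider',
#                                          'CPUExecutionProvider'])
#   outs = sess.run(None, {'input': one_img.detach().cpu().numpy()})
#   dets, labels, feature, entropy, learning_loss = outs[:5]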

def parse_normalize_cfg(test_pipeline):
    transforms = None
    for pipeline in test_pipeline:
        if 'transforms' in pipeline:
            transforms = pipeline['transforms']
            break
    assert transforms is not None, 'Failed to find `transforms`'
    norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
    assert len(norm_config_li) == 1, '`norm_config` should only have one'
    norm_config = norm_config_li[0]
    return norm_config

def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument(
        '--train-work-dir',
        default='/model',
        help='directory containing config.py and the checkpoint (.pth) files')
    parser.add_argument(
        '--work-dir',
        default='/result',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        help='input image shape for ONNX export, e.g. --shape 608 or '
        '--shape 608 608 (defaults to the img_scale of the test pipeline)')
    parser.add_argument(
        '--data-path', default='/dataset', help='dataset path')
    parser.add_argument(
        '--out',
        default='/result/results.pkl',  # main() rejects non-.pkl paths, so
        # the default points at a file rather than a directory
        help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn; this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        default='bbox',
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecated), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args
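
# Example invocation (the script name is illustrative; the defaults in
# parse_args cover the container layout /model, /dataset, /result):
#
#   python eval_and_export.py --train-work-dir /model --data-path /dataset \
#       --work-dir /result --out /result/results.pkl --eval bbox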

def main():
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results) with the argument "--out", "--eval", "--format-only", '
         '"--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format-only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = Config.fromfile(osp.join(args.train_work_dir, 'config.py'))
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    if cfg.model.get('neck'):
        if isinstance(cfg.model.neck, list):
            for neck_cfg in cfg.model.neck:
                if neck_cfg.get('rfp_backbone'):
                    if neck_cfg.rfp_backbone.get('pretrained'):
                        neck_cfg.rfp_backbone.pretrained = None
        elif cfg.model.neck.get('rfp_backbone'):
            if cfg.model.neck.rfp_backbone.get('pretrained'):
                cfg.model.neck.rfp_backbone.pretrained = None

    # in case the test dataset is concatenated
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()
    # allows not to create the work_dir
    if args.work_dir is not None and rank == 0:
        mmcv.mkdir_or_exist(args.work_dir)
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())

    # point the test set at the mounted dataset and read the class names
    # from the COCO annotation file
    if args.data_path is not None:
        ann_file = os.path.join(args.data_path,
                                'annotations/instances_annotations.json')
        coco_config = COCO(ann_file)
        cfg.data.test.img_prefix = os.path.join(args.data_path, 'images')
        cfg.data.test.ann_file = ann_file
        cfg.classes = ()
        for cat in coco_config.cats.values():
            cfg.classes = cfg.classes + tuple([cat['name']])
        cfg.data.test.classes = cfg.classes

    # build the dataloader
    samples_per_gpu = 1
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # evaluate every checkpoint in train_work_dir and keep the best one
    eval_results = []
    best_eval_result = {'checkpoint': 'epoch_1.pth', 'AUC': 0, 'bbox_mAP_50': 0}
    checkpoint_files = os.listdir(args.train_work_dir)
    for checkpoint_file in checkpoint_files:
        if not checkpoint_file.endswith('pth'):
            continue
        # build the model and load checkpoint
        cfg.model.train_cfg = None
        model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
        fp16_cfg = cfg.get('fp16', None)
        if fp16_cfg is not None:
            wrap_fp16_model(model)
        checkpoint = load_checkpoint(
            model, osp.join(args.train_work_dir, checkpoint_file),
            map_location='cpu')
        if args.fuse_conv_bn:
            model = fuse_conv_bn(model)
        # old versions did not save class info in checkpoints; this workaround
        # is for backward compatibility
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            model.CLASSES = dataset.CLASSES

        if not distributed:
            model = MMDataParallel(model, device_ids=[0])
            outputs = single_gpu_test(model, data_loader, args.show,
                                      args.show_dir, args.show_score_thr)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False)
            outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                     args.gpu_collect)

        rank, _ = get_dist_info()
        if rank == 0:
            if args.out:
                print(f'\nwriting results to {args.out}')
                mmcv.dump(outputs, args.out)
            kwargs = {} if args.eval_options is None else args.eval_options
            if args.format_only:
                dataset.format_results(outputs, **kwargs)
            if args.eval:
                eval_kwargs = cfg.get('evaluation', {}).copy()
                # hard-coded way to remove EvalHook args
                for key in [
                        'interval', 'tmpdir', 'start', 'gpu_collect',
                        'save_best', 'rule'
                ]:
                    eval_kwargs.pop(key, None)
                eval_kwargs.update(dict(metric=args.eval, **kwargs))
                metric = dataset.evaluate(outputs, **eval_kwargs)
                print(metric)
                print(metric['AUC'])
                print(metric['bbox_mAP_50'])
                eval_result = {
                    'checkpoint': checkpoint_file,
                    'AUC': metric['AUC'],
                    'bbox_mAP_50': metric['bbox_mAP_50']
                }
                eval_results.append(eval_result)
                if (eval_result['AUC'] + eval_result['bbox_mAP_50'] >
                        best_eval_result['AUC'] + best_eval_result['bbox_mAP_50']):
                    best_eval_result = eval_result

    print(eval_results)
    print(best_eval_result)
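
    # Checkpoint selection above uses the unweighted sum AUC + bbox_mAP_50;
    # both metrics come from the (customized) dataset.evaluate() and are
    # assumed to lie on comparable 0-1 scales.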

    if args.shape is None:
        # fall back to the test pipeline's img_scale (w, h) for the export shape
        img_scale = cfg.test_pipeline[1]['img_scale'][0]
        print(img_scale)
        input_shape = (1, 3, img_scale[1], img_scale[0])
    elif len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    # create onnx dir
    onnx_path = osp.join(args.work_dir, 'infer')
    mmcv.mkdir_or_exist(onnx_path)

    # build the model and load the best checkpoint
    model = build_model_from_cfg(
        osp.join(args.train_work_dir, 'config.py'),
        osp.join(args.train_work_dir, best_eval_result['checkpoint']))
    input_img = osp.join(osp.dirname(__file__), 'demo.jpg')
    normalize_cfg = parse_normalize_cfg(cfg.test_pipeline)

    # convert model to onnx file
    pytorch2onnx(
        model,
        input_img,
        input_shape,
        normalize_cfg,
        output_file=osp.join(onnx_path, 'model.onnx'),
        test_img=input_img)

    ann_file = os.path.join(args.data_path,
                            'annotations/instances_annotations.json')
    img_dir = os.path.join(args.data_path, 'images')
    recall_thrs = cal_recall(
        osp.join(args.train_work_dir, 'config.py'),
        osp.join(args.train_work_dir, best_eval_result['checkpoint']),
        img_dir, ann_file)
    best_eval_result['recall'] = recall_thrs
    print(best_eval_result)
    json_file = osp.join(args.work_dir, 'eval_result.json')
    mmcv.dump(best_eval_result, json_file)

    train_feats = cal_features(
        osp.join(args.train_work_dir, 'config.py'),
        osp.join(args.train_work_dir, best_eval_result['checkpoint']),
        img_dir, ann_file)
    train_feats = np.array(train_feats)
    print(train_feats.shape)
    train_mean = np.mean(train_feats, axis=0)
    # shrinkage-regularized covariance of the training features
    train_cov = LedoitWolf().fit(train_feats).covariance_
    train_cov_inv = np.linalg.pinv(train_cov)
    print(train_mean.shape, train_cov.shape, train_cov_inv.shape)

    shutil.copy(osp.join(args.train_work_dir, 'config.py'),
                osp.join(args.work_dir, 'infer/config.py'))
    shutil.copy(osp.join(args.train_work_dir, best_eval_result['checkpoint']),
                osp.join(args.work_dir, 'infer/' + best_eval_result['checkpoint']))
    shutil.copytree(
        osp.abspath(osp.join(osp.dirname(__file__), '../../transformer/')),
        osp.join(args.work_dir, 'infer/transformer'))
    with open(osp.join(args.work_dir, 'infer/class_names.txt'), 'w') as class_name_file:
        for name in cfg.classes:
            class_name_file.write(name + '\n')
    print(osp.join(args.work_dir, 'infer/class_names.txt'))
    # np.savez always appends '.npz', so name the archive accordingly
    np.savez(osp.join(args.work_dir, 'infer/train_feature.npz'),
             train_mean=train_mean, train_cov=train_cov,
             train_cov_inv=train_cov_inv)
    print(osp.join(args.work_dir, 'infer/train_feature.npz'))
    shutil.copy(osp.abspath(osp.join(osp.dirname(__file__), 'serve_desc.yaml')),
                osp.join(args.work_dir, 'infer/serve_desc.yaml'))
    shutil.copy(osp.abspath(osp.join(osp.dirname(__file__), 'ext.proto')),
                osp.join(args.work_dir, 'infer/transformer/ext.proto'))


if __name__ == '__main__':
    main()
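
# The saved mean / covariance / inverse covariance are the usual ingredients
# of a Mahalanobis-distance score; how the inference service uses them is not
# shown here, but a minimal sketch (hypothetical) would be:
#
#   stats = np.load('/result/infer/train_feature.npz')
#   diff = feat - stats['train_mean']
#   score = float(np.sqrt(diff @ stats['train_cov_inv'] @ diff))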
