You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 21 kB

2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import copy
  4. import os
  5. import os.path as osp
  6. import time
  7. import warnings
  8. import shutil
  9. import sys
  10. path = os.path.dirname(os.path.dirname(__file__))
  11. print(path)
  12. sys.path.append("/tmp/code/code_test")
  13. os.system("pip install --no-cache-dir onnx==1.11.0 onnxruntime==1.11.1 protobuf==3.20.0")
  14. #os.environ['RANK'] = "0"
  15. #os.environ['WORLD_SIZE'] = "8"
  16. #os.environ['MASTER_ADDR'] = "localhost"
  17. #os.environ['MASTER_PORT'] = "1234"
  18. import mmcv
  19. import torch
  20. #os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"
  21. from mmcv import Config, DictAction
  22. from mmcv.runner import get_dist_info, init_dist
  23. from mmcv.utils import get_git_hash
  24. from pycocotools.coco import COCO
  25. from mmdet import __version__
  26. from mmdet.apis import init_random_seed, set_random_seed, train_detector
  27. from mmdet.datasets import build_dataset
  28. from mmdet.models import build_detector
  29. from mmdet.utils import collect_env, get_root_logger
  30. # Copyright (c) OpenMMLab. All rights reserved.
  31. from functools import partial
  32. import numpy as np
  33. from sklearn.covariance import LedoitWolf
  34. from mmdet.core.export import build_model_from_cfg, preprocess_example_input
  35. from mmdet.core.export.model_wrappers import ONNXRuntimeDetector
  36. from mmdet.apis import (async_inference_detector, inference_detector,
  37. init_detector, show_result_pyplot)
  38. import onnxruntime as ort
  39. import onnx
# Report, at import time, which device and execution providers the installed
# onnxruntime build exposes (the trailing comments show example output).
print(f"onnxruntime device: {ort.get_device()}") # output: GPU
print(f'ort avail providers: {ort.get_available_providers()}') # output: ['CUDAExecutionProvider', 'CPUExecutionProvider']
  42. def parse_normalize_cfg(test_pipeline):
  43. transforms = None
  44. for pipeline in test_pipeline:
  45. if 'transforms' in pipeline:
  46. transforms = pipeline['transforms']
  47. break
  48. assert transforms is not None, 'Failed to find `transforms`'
  49. norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
  50. assert len(norm_config_li) == 1, '`norm_config` should only have one'
  51. norm_config = norm_config_li[0]
  52. return norm_config
  53. def pytorch2onnx(model,
  54. input_img,
  55. input_shape,
  56. normalize_cfg,
  57. opset_version=11,
  58. show=False,
  59. output_file='model.onnx',
  60. verify=False,
  61. test_img=None,
  62. do_simplify=False,
  63. dynamic_export=True,
  64. skip_postprocess=False):
  65. input_config = {
  66. 'input_shape': input_shape,
  67. 'input_path': input_img,
  68. 'normalize_cfg': normalize_cfg
  69. }
  70. # prepare input
  71. one_img, one_meta = preprocess_example_input(input_config)
  72. img_list, img_meta_list = [one_img], [[one_meta]]
  73. if skip_postprocess:
  74. warnings.warn('Not all models support export onnx without post '
  75. 'process, especially two stage detectors!')
  76. model.forward = model.forward_dummy
  77. torch.onnx.export(
  78. model,
  79. one_img,
  80. output_file,
  81. input_names=['input'],
  82. export_params=True,
  83. keep_initializers_as_inputs=True,
  84. do_constant_folding=True,
  85. verbose=show,
  86. opset_version=opset_version)
  87. print(f'Successfully exported ONNX model without '
  88. f'post process: {output_file}')
  89. return
  90. # replace original forward function
  91. origin_forward = model.forward
  92. model.forward = partial(
  93. model.forward,
  94. img_metas=img_meta_list,
  95. return_loss=False,
  96. rescale=False)
  97. output_names = ['dets', 'labels', 'feature', 'entropy', 'learning_loss']
  98. if model.with_mask:
  99. output_names.append('masks')
  100. input_name = 'input'
  101. dynamic_axes = None
  102. if dynamic_export:
  103. dynamic_axes = {
  104. input_name: {
  105. 0: 'batch',
  106. 2: 'height',
  107. 3: 'width'
  108. },
  109. 'dets': {
  110. 0: 'batch',
  111. 1: 'num_dets',
  112. },
  113. 'labels': {
  114. 0: 'batch',
  115. 1: 'num_dets',
  116. },
  117. 'feature': {
  118. 0: 'batch',
  119. 1: 'feat_dim',
  120. },
  121. 'entropy': {
  122. 0: 'batch',
  123. 1: '1',
  124. },
  125. 'learning_loss': {
  126. 0: 'batch',
  127. 1: '1',
  128. },
  129. }
  130. if model.with_mask:
  131. dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}
  132. torch.onnx.export(
  133. model,
  134. img_list,
  135. output_file,
  136. input_names=[input_name],
  137. output_names=output_names,
  138. export_params=True,
  139. keep_initializers_as_inputs=True,
  140. do_constant_folding=True,
  141. verbose=show,
  142. opset_version=opset_version,
  143. dynamic_axes=dynamic_axes)
  144. model.forward = origin_forward
  145. # get the custom op path
  146. ort_custom_op_path = ''
  147. try:
  148. from mmcv.ops import get_onnxruntime_op_path
  149. ort_custom_op_path = get_onnxruntime_op_path()
  150. except (ImportError, ModuleNotFoundError):
  151. warnings.warn('If input model has custom op from mmcv, \
  152. you may have to build mmcv with ONNXRuntime from source.')
  153. if do_simplify:
  154. import onnxsim
  155. from mmdet import digit_version
  156. min_required_version = '0.3.0'
  157. assert digit_version(onnxsim.__version__) >= digit_version(
  158. min_required_version
  159. ), f'Requires to install onnx-simplify>={min_required_version}'
  160. input_dic = {'input': img_list[0].detach().cpu().numpy()}
  161. model_opt, check_ok = onnxsim.simplify(
  162. output_file,
  163. input_data=input_dic,
  164. custom_lib=ort_custom_op_path,
  165. dynamic_input_shape=dynamic_export)
  166. if check_ok:
  167. onnx.save(model_opt, output_file)
  168. print(f'Successfully simplified ONNX model: {output_file}')
  169. else:
  170. warnings.warn('Failed to simplify ONNX model.')
  171. print(f'Successfully exported ONNX model: {output_file}')
  172. if verify:
  173. # check by onnx
  174. onnx_model = onnx.load(output_file)
  175. onnx.checker.check_model(onnx_model)
  176. # wrap onnx model
  177. onnx_model = ONNXRuntimeDetector(output_file, model.CLASSES, 0)
  178. if dynamic_export:
  179. # scale up to test dynamic shape
  180. h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
  181. h, w = min(1344, h), min(1344, w)
  182. input_config['input_shape'] = (1, 3, h, w)
  183. if test_img is None:
  184. input_config['input_path'] = input_img
  185. # prepare input once again
  186. one_img, one_meta = preprocess_example_input(input_config)
  187. img_list, img_meta_list = [one_img], [[one_meta]]
  188. # get pytorch output
  189. with torch.no_grad():
  190. pytorch_results = model(
  191. img_list,
  192. img_metas=img_meta_list,
  193. return_loss=False,
  194. rescale=True)[0]
  195. img_list = [_.cuda().contiguous() for _ in img_list]
  196. if dynamic_export:
  197. img_list = img_list + [_.flip(-1).contiguous() for _ in img_list]
  198. img_meta_list = img_meta_list * 2
  199. # get onnx output
  200. onnx_results = onnx_model(
  201. img_list, img_metas=img_meta_list, return_loss=False)[0]
  202. # visualize predictions
  203. score_thr = 0.3
  204. if show:
  205. out_file_ort, out_file_pt = None, None
  206. else:
  207. out_file_ort, out_file_pt = 'show-ort.png', 'show-pt.png'
  208. show_img = one_meta['show_img']
  209. model.show_result(
  210. show_img,
  211. pytorch_results,
  212. score_thr=score_thr,
  213. show=True,
  214. win_name='PyTorch',
  215. out_file=out_file_pt)
  216. onnx_model.show_result(
  217. show_img,
  218. onnx_results,
  219. score_thr=score_thr,
  220. show=True,
  221. win_name='ONNXRuntime',
  222. out_file=out_file_ort)
  223. # compare a part of result
  224. '''print(input_config['input_shape'])
  225. print(one_img)
  226. print(len(onnx_results))
  227. print(len(pytorch_results))
  228. print(onnx_results)
  229. print(pytorch_results)'''
  230. for i in range(len(onnx_results)):
  231. print(onnx_results[i].shape)
  232. print("***************")
  233. for i in range(len(pytorch_results)):
  234. print(pytorch_results[i].shape)
  235. if model.with_mask:
  236. compare_pairs = list(zip(onnx_results, pytorch_results))
  237. else:
  238. compare_pairs = [(onnx_results, pytorch_results)]
  239. err_msg = 'The numerical values are different between Pytorch' + \
  240. ' and ONNX, but it does not necessarily mean the' + \
  241. ' exported ONNX model is problematic.'
  242. # check the numerical value
  243. for onnx_res, pytorch_res in compare_pairs:
  244. for o_res, p_res in zip(onnx_res, pytorch_res):
  245. np.testing.assert_allclose(
  246. o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
  247. print('The numerical values are the same between Pytorch and ONNX')
  248. def parse_args():
  249. parser = argparse.ArgumentParser(description='Train a detector')
  250. parser.add_argument('--config', default='/tmp/code/code_test/configs/AD_mlops/AD_mlops_test18.py', help='train config file path')
  251. parser.add_argument('--work-dir',default='/tmp/output', help='the dir to save logs and models')
  252. parser.add_argument(
  253. '--resume-from', help='the checkpoint file to resume from')
  254. parser.add_argument(
  255. '--no-validate',
  256. action='store_true',
  257. help='whether not to evaluate the checkpoint during training')
  258. parser.add_argument(
  259. '--shape',
  260. help='infer image shape')
  261. group_gpus = parser.add_mutually_exclusive_group()
  262. group_gpus.add_argument(
  263. '--gpus',
  264. type=int,
  265. help='number of gpus to use '
  266. '(only applicable to non-distributed training)')
  267. group_gpus.add_argument(
  268. '--gpu-ids',
  269. type=int,
  270. nargs='+',
  271. help='ids of gpus to use '
  272. '(only applicable to non-distributed training)')
  273. parser.add_argument('--seed', type=int, default=None, help='random seed')
  274. parser.add_argument(
  275. '--deterministic',
  276. action='store_true',
  277. help='whether to set deterministic options for CUDNN backend.')
  278. parser.add_argument(
  279. '--options',
  280. nargs='+',
  281. action=DictAction,
  282. help='override some settings in the used config, the key-value pair '
  283. 'in xxx=yyy format will be merged into config file (deprecate), '
  284. 'change to --cfg-options instead.')
  285. parser.add_argument(
  286. '--cfg-options',
  287. nargs='+',
  288. action=DictAction,
  289. help='override some settings in the used config, the key-value pair '
  290. 'in xxx=yyy format will be merged into config file. If the value to '
  291. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  292. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  293. 'Note that the quotation marks are necessary and that no white space '
  294. 'is allowed.')
  295. parser.add_argument(
  296. '--launcher',
  297. choices=['none', 'pytorch', 'slurm', 'mpi'],
  298. default='none',
  299. help='job launcher')
  300. parser.add_argument('--local_rank', type=int, default=0)
  301. parser.add_argument(
  302. '--data-path', default='/tmp/dataset', help='dataset path')
  303. parser.add_argument(
  304. '--batchsize',
  305. type=int,
  306. default=8,
  307. help='training batch size')
  308. parser.add_argument(
  309. '--epoch',
  310. type=int,
  311. default=2,
  312. help='training epoch')
  313. parser.add_argument(
  314. '--warmup_iters',
  315. type=int,
  316. default=500,
  317. help='training warmup_iters')
  318. parser.add_argument(
  319. '--lr',
  320. type=float,
  321. default=0.001,
  322. help='learning rate')
  323. parser.add_argument('--train_image_size',
  324. type=list,
  325. default=[(100, 100)],
  326. help='train image size')
  327. parser.add_argument('--test_image_size',
  328. type=list,
  329. default=[(100, 100)],
  330. help='test image size')
  331. args = parser.parse_args()
  332. if 'LOCAL_RANK' not in os.environ:
  333. os.environ['LOCAL_RANK'] = str(args.local_rank)
  334. if args.options and args.cfg_options:
  335. raise ValueError(
  336. '--options and --cfg-options cannot be both '
  337. 'specified, --options is deprecated in favor of --cfg-options')
  338. if args.options:
  339. warnings.warn('--options is deprecated in favor of --cfg-options')
  340. args.cfg_options = args.options
  341. return args
  342. def main():
  343. args = parse_args()
  344. cfg = Config.fromfile(args.config)
  345. if args.cfg_options is not None:
  346. cfg.merge_from_dict(args.cfg_options)
  347. # import modules from string list.
  348. if cfg.get('custom_imports', None):
  349. from mmcv.utils import import_modules_from_strings
  350. import_modules_from_strings(**cfg['custom_imports'])
  351. # set cudnn_benchmark
  352. if cfg.get('cudnn_benchmark', False):
  353. torch.backends.cudnn.benchmark = True
  354. if args.batchsize is not None:
  355. cfg.data.samples_per_gpu = args.batchsize
  356. if args.epoch is not None:
  357. cfg.runner.max_epochs = args.epoch
  358. if args.warmup_iters is not None:
  359. cfg.lr_config.warmup_iters = args.warmup_iters
  360. if args.lr is not None:
  361. cfg.optimizer.lr = args.lr
  362. '''if args.train_image_size is not None:
  363. cfg.train_pipeline[2].img_scale = args.train_image_size
  364. cfg.data.train.dataset.pipeline[2].img_scale = args.train_image_size'''
  365. if args.test_image_size is not None:
  366. cfg.test_pipeline[1].img_scale = args.test_image_size
  367. cfg.data.val.pipeline[1].img_scale = args.test_image_size
  368. cfg.data.test.pipeline[1].img_scale = args.test_image_size
  369. #if on platform, change the classnum fit the user define dataset
  370. if args.data_path is not None:
  371. coco_config=COCO(os.path.join(args.data_path,"annotations/instances_annotations.json"))
  372. cfg.data.train.img_prefix = os.path.join(args.data_path,"images")
  373. cfg.data.train.ann_file = os.path.join(args.data_path,"annotations/instances_annotations.json")
  374. cfg.data.val.img_prefix = os.path.join(args.data_path,"images")
  375. cfg.data.val.ann_file = os.path.join(args.data_path,"annotations/instances_annotations.json")
  376. cfg.data.test.img_prefix = os.path.join(args.data_path,"images")
  377. cfg.data.test.ann_file = os.path.join(args.data_path,"annotations/instances_annotations.json")
  378. cfg.classes = ()
  379. for cat in coco_config.cats.values():
  380. cfg.classes = cfg.classes + tuple([cat['name']])
  381. cfg.data.train.classes = cfg.classes
  382. cfg.data.val.classes = cfg.classes
  383. cfg.data.test.classes = cfg.classes
  384. #some model will RepeatDataset to speed up training, make sure all dataset path replace to data_path
  385. #cfg = Config.fromstring(cfg.dump().replace("ann_file='data/coco/annotations/instances_train2017.json',","ann_file='{}',".format(os.path.join(args.data_path,"annotations/instances_annotations.json"))), ".py")
  386. #cfg = Config.fromstring(cfg.dump().replace("img_prefix='data/coco/train2017/',","img_prefix='{}',".format(os.path.join(args.data_path,"images"))), ".py")
  387. # replace the classes num fit userdefine dataset
  388. #cfg = Config.fromstring(cfg.dump().replace("num_classes=80","num_classes={0}".format(len(coco_config.getCatIds()))), ".py")
  389. cfg.model.bbox_head.num_classes = len(coco_config.getCatIds())
  390. print(cfg.dump())
  391. # work_dir is determined in this priority: CLI > segment in file > filename
  392. if args.work_dir is not None:
  393. # update configs according to CLI args if args.work_dir is not None
  394. cfg.work_dir = args.work_dir
  395. elif cfg.get('work_dir', None) is None:
  396. # use config filename as default work_dir if cfg.work_dir is None
  397. cfg.work_dir = osp.join('./work_dirs',
  398. osp.splitext(osp.basename(args.config))[0])
  399. if args.resume_from is not None:
  400. cfg.resume_from = args.resume_from
  401. if args.gpu_ids is not None:
  402. cfg.gpu_ids = args.gpu_ids
  403. else:
  404. cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
  405. # init distributed env first, since logger depends on the dist info.
  406. if args.launcher == 'none':
  407. distributed = False
  408. else:
  409. distributed = True
  410. init_dist(args.launcher, **cfg.dist_params)
  411. # re-set gpu_ids with distributed training mode
  412. _, world_size = get_dist_info()
  413. cfg.gpu_ids = range(world_size)
  414. # create work_dir
  415. mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
  416. # dump config
  417. cfg.dump(osp.join(cfg.work_dir, 'config.py'))
  418. # init the logger before other steps
  419. timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
  420. log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
  421. logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
  422. # init the meta dict to record some important information such as
  423. # environment info and seed, which will be logged
  424. meta = dict()
  425. # log env info
  426. env_info_dict = collect_env()
  427. env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
  428. dash_line = '-' * 60 + '\n'
  429. logger.info('Environment info:\n' + dash_line + env_info + '\n' +
  430. dash_line)
  431. meta['env_info'] = env_info
  432. meta['config'] = cfg.pretty_text
  433. # log some basic info
  434. logger.info(f'Distributed training: {distributed}')
  435. logger.info(f'Config:\n{cfg.pretty_text}')
  436. # set random seeds
  437. #seed = init_random_seed(args.seed)
  438. seed = 965702173
  439. logger.info(f'Set random seed to {seed}, '
  440. f'deterministic: {args.deterministic}')
  441. set_random_seed(seed, deterministic=args.deterministic)
  442. #set_random_seed(seed, deterministic=True)
  443. cfg.seed = seed
  444. meta['seed'] = seed
  445. meta['exp_name'] = osp.basename(args.config)
  446. model = build_detector(
  447. cfg.model,
  448. train_cfg=cfg.get('train_cfg'),
  449. test_cfg=cfg.get('test_cfg'))
  450. model.init_weights()
  451. datasets = [build_dataset(cfg.data.train)]
  452. if len(cfg.workflow) == 2:
  453. val_dataset = copy.deepcopy(cfg.data.val)
  454. val_dataset.pipeline = cfg.data.train.pipeline
  455. datasets.append(build_dataset(val_dataset))
  456. if cfg.checkpoint_config is not None:
  457. # save mmdet version, config file content and class names in
  458. # checkpoints as meta data
  459. cfg.checkpoint_config.meta = dict(
  460. mmdet_version=__version__ + get_git_hash()[:7],
  461. CLASSES=datasets[0].CLASSES)
  462. # add an attribute for visualization convenience
  463. model.CLASSES = datasets[0].CLASSES
  464. train_detector(
  465. model,
  466. datasets,
  467. cfg,
  468. distributed=distributed,
  469. validate=(not args.no_validate),
  470. timestamp=timestamp,
  471. meta=meta)
  472. if args.shape is None:
  473. img_scale = cfg.test_pipeline[1]['img_scale'][0]
  474. print(img_scale)
  475. input_shape = (1, 3, img_scale[1], img_scale[0])
  476. elif len(args.shape) == 1:
  477. input_shape = (1, 3, args.shape[0], args.shape[0])
  478. elif len(args.shape) == 2:
  479. input_shape = (1, 3) + tuple(args.shape)
  480. else:
  481. raise ValueError('invalid input shape')
  482. # create onnx dir
  483. onnx_path = osp.join(args.work_dir)
  484. model = build_model_from_cfg(osp.join(args.work_dir, 'config.py'), osp.join(args.work_dir, "latest.pth"))
  485. input_img = osp.join(osp.dirname(__file__), 'demo.jpg')
  486. normalize_cfg = parse_normalize_cfg(cfg.test_pipeline)
  487. # convert model to onnx file
  488. pytorch2onnx(
  489. model,
  490. input_img,
  491. input_shape,
  492. normalize_cfg,
  493. output_file=osp.join(onnx_path,'model.onnx'),
  494. test_img=input_img)
  495. #启智平台
  496. shutil.copytree(osp.abspath(osp.join(osp.dirname(__file__),'../transformer/')), osp.join(args.work_dir, "transformer"))
  497. #shutil.copy(osp.join(args.train_work_dir, "config.py"), osp.join(args.work_dir, "config.py"))
  498. class_name_file = open(osp.join(args.work_dir, "class_names.txt"), 'w')
  499. for name in cfg.classes:
  500. class_name_file.write(name+'\n')
  501. shutil.copy(osp.abspath(osp.join(osp.dirname(__file__),'serve_desc.yaml')), osp.join(args.work_dir, "serve_desc.yaml"))
# Run the train-and-export pipeline only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()

No Description

Contributors (3)