
train-checkpoint.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import shutil
import sys
import time
import warnings

# make the project code importable when the script is launched from elsewhere
path = os.path.dirname(os.path.dirname(__file__))
print(path)
sys.path.append("/home/kaijie-tang/userdata/qizhi/code_test")

# manual single-node distributed settings, kept for reference:
# os.environ['RANK'] = "0"
# os.environ['WORLD_SIZE'] = "8"
# os.environ['MASTER_ADDR'] = "localhost"
# os.environ['MASTER_PORT'] = "1234"

import mmcv
import torch

# os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
from pycocotools.coco import COCO

from mmdet import __version__
from mmdet.apis import init_random_seed, set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument(
        '--config',
        default='/home/kaijie-tang/userdata/qizhi/code_test/configs/AD_mlops/AD_mlops_test18.py',
        help='train config file path')
    parser.add_argument(
        '--work-dir',
        default='/home/kaijie-tang/outputs/',
        help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for the CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecated), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--data-path',
        default='/home/kaijie-tang/adhub/coco2017_train/0.1',
        help='dataset path')
    parser.add_argument(
        '--batchsize',
        type=int,
        default=8,
        help='training batch size (samples per GPU)')
    parser.add_argument(
        '--epoch',
        type=int,
        default=5,
        help='number of training epochs')
    parser.add_argument(
        '--warmup_iters',
        type=int,
        default=500,
        help='number of learning-rate warmup iterations')
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        help='learning rate')
    parser.add_argument(
        '--train_image_size',
        type=list,
        default=[(100, 100)],
        help='train image size')
    parser.add_argument(
        '--test_image_size',
        type=list,
        default=[(100, 100)],
        help='test image size')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # override config hyper-parameters with the command-line values
    if args.batchsize is not None:
        cfg.data.samples_per_gpu = args.batchsize
    if args.epoch is not None:
        cfg.runner.max_epochs = args.epoch
    if args.warmup_iters is not None:
        cfg.lr_config.warmup_iters = args.warmup_iters
    if args.lr is not None:
        cfg.optimizer.lr = args.lr
    # if args.train_image_size is not None:
    #     cfg.train_pipeline[2].img_scale = args.train_image_size
    #     cfg.data.train.dataset.pipeline[2].img_scale = args.train_image_size
    if args.test_image_size is not None:
        cfg.test_pipeline[1].img_scale = args.test_image_size
        cfg.data.val.pipeline[1].img_scale = args.test_image_size
        cfg.data.test.pipeline[1].img_scale = args.test_image_size

    # when running on the platform, point the config at the user-defined dataset
    # and adjust the number of classes to match its annotations
    if args.data_path is not None:
        coco_config = COCO(os.path.join(args.data_path, "annotations/instances_annotations.json"))
        cfg.data.train.dataset.img_prefix = os.path.join(args.data_path, "images")
        cfg.data.train.dataset.ann_file = os.path.join(args.data_path, "annotations/instances_annotations.json")
        cfg.data.val.img_prefix = os.path.join(args.data_path, "images")
        cfg.data.val.ann_file = os.path.join(args.data_path, "annotations/instances_annotations.json")
        cfg.data.test.img_prefix = os.path.join(args.data_path, "images")
        cfg.data.test.ann_file = os.path.join(args.data_path, "annotations/instances_annotations.json")
        cfg.classes = ()
        for cat in coco_config.cats.values():
            cfg.classes = cfg.classes + tuple([cat['name']])
        cfg.data.train.dataset.classes = cfg.classes
        cfg.data.val.classes = cfg.classes
        cfg.data.test.classes = cfg.classes
        # some models wrap the training set in a RepeatDataset to speed up training;
        # make sure every dataset path in the config is replaced with data_path, e.g.
        # cfg = Config.fromstring(cfg.dump().replace("ann_file='data/coco/annotations/instances_train2017.json',", "ann_file='{}',".format(os.path.join(args.data_path, "annotations/instances_annotations.json"))), ".py")
        # cfg = Config.fromstring(cfg.dump().replace("img_prefix='data/coco/train2017/',", "img_prefix='{}',".format(os.path.join(args.data_path, "images"))), ".py")
        # replace the number of classes to fit the user-defined dataset, e.g.
        # cfg = Config.fromstring(cfg.dump().replace("num_classes=80", "num_classes={0}".format(len(coco_config.getCatIds()))), ".py")
        cfg.model.bbox_head.num_classes = len(coco_config.getCatIds())
    print(cfg.dump())
    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, 'config.py'))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    # seed = init_random_seed(args.seed)
    seed = 965702173
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)
    # set_random_seed(seed, deterministic=True)
    cfg.seed = seed
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__ + get_git_hash()[:7],
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
    # Qizhi (OpenI) platform: export serving artifacts alongside the training outputs
    shutil.copytree(osp.abspath(osp.join(osp.dirname(__file__), '../../transformer/')), osp.join(args.work_dir, "transformer"))
    # the config was dumped into cfg.work_dir above (the original code referenced an
    # undefined args.train_work_dir); copy it to the export dir if the two differ
    if osp.abspath(cfg.work_dir) != osp.abspath(args.work_dir):
        shutil.copy(osp.join(cfg.work_dir, "config.py"), osp.join(args.work_dir, "config.py"))
    with open(osp.join(args.work_dir, "class_names.txt"), 'w') as class_name_file:
        for name in cfg.classes:
            class_name_file.write(name + '\n')
    shutil.copy(osp.abspath(osp.join(osp.dirname(__file__), 'serve_desc.yaml')), osp.join(args.work_dir, "serve_desc.yaml"))


if __name__ == '__main__':
    main()
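
For reference, the dataset-adaptation step in main() boils down to reading the COCO-style annotation file with pycocotools and turning its categories into the class tuple and class count that are written back into the config. Below is a minimal standalone sketch of that step; the annotation path is a placeholder, not part of the script.

    # Standalone sketch of the class-derivation step; the annotation path is a placeholder.
    from pycocotools.coco import COCO

    ann_file = "/path/to/annotations/instances_annotations.json"  # placeholder path
    coco = COCO(ann_file)

    # coco.cats maps category id -> {'id': ..., 'name': ..., 'supercategory': ...}
    classes = tuple(cat['name'] for cat in coco.cats.values())
    num_classes = len(coco.getCatIds())

    print(classes)       # becomes cfg.classes and the per-dataset classes
    print(num_classes)   # becomes cfg.model.bbox_head.num_classes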
