You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mmdet2torchserve.py 3.7 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import mmcv

try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None
  11. def mmdet2torchserve(
  12. config_file: str,
  13. checkpoint_file: str,
  14. output_folder: str,
  15. model_name: str,
  16. model_version: str = '1.0',
  17. force: bool = False,
  18. ):
  19. """Converts MMDetection model (config + checkpoint) to TorchServe `.mar`.
  20. Args:
  21. config_file:
  22. In MMDetection config format.
  23. The contents vary for each task repository.
  24. checkpoint_file:
  25. In MMDetection checkpoint format.
  26. The contents vary for each task repository.
  27. output_folder:
  28. Folder where `{model_name}.mar` will be created.
  29. The file created will be in TorchServe archive format.
  30. model_name:
  31. If not None, used for naming the `{model_name}.mar` file
  32. that will be created under `output_folder`.
  33. If None, `{Path(checkpoint_file).stem}` will be used.
  34. model_version:
  35. Model's version.
  36. force:
  37. If True, if there is an existing `{model_name}.mar`
  38. file under `output_folder` it will be overwritten.
  39. """
  40. mmcv.mkdir_or_exist(output_folder)
  41. config = mmcv.Config.fromfile(config_file)
  42. with TemporaryDirectory() as tmpdir:
  43. config.dump(f'{tmpdir}/config.py')
  44. args = Namespace(
  45. **{
  46. 'model_file': f'{tmpdir}/config.py',
  47. 'serialized_file': checkpoint_file,
  48. 'handler': f'{Path(__file__).parent}/mmdet_handler.py',
  49. 'model_name': model_name or Path(checkpoint_file).stem,
  50. 'version': model_version,
  51. 'export_path': output_folder,
  52. 'force': force,
  53. 'requirements_file': None,
  54. 'extra_files': None,
  55. 'runtime': 'python',
  56. 'archive_format': 'default'
  57. })
  58. manifest = ModelExportUtils.generate_manifest_json(args)
  59. package_model(args, manifest)
  60. def parse_args():
  61. parser = ArgumentParser(
  62. description='Convert MMDetection models to TorchServe `.mar` format.')
  63. parser.add_argument('config', type=str, help='config file path')
  64. parser.add_argument('checkpoint', type=str, help='checkpoint file path')
  65. parser.add_argument(
  66. '--output-folder',
  67. type=str,
  68. required=True,
  69. help='Folder where `{model_name}.mar` will be created.')
  70. parser.add_argument(
  71. '--model-name',
  72. type=str,
  73. default=None,
  74. help='If not None, used for naming the `{model_name}.mar`'
  75. 'file that will be created under `output_folder`.'
  76. 'If None, `{Path(checkpoint_file).stem}` will be used.')
  77. parser.add_argument(
  78. '--model-version',
  79. type=str,
  80. default='1.0',
  81. help='Number used for versioning.')
  82. parser.add_argument(
  83. '-f',
  84. '--force',
  85. action='store_true',
  86. help='overwrite the existing `{model_name}.mar`')
  87. args = parser.parse_args()
  88. return args
  89. if __name__ == '__main__':
  90. args = parse_args()
  91. if package_model is None:
  92. raise ImportError('`torch-model-archiver` is required.'
  93. 'Try: pip install torch-model-archiver')
  94. mmdet2torchserve(args.config, args.checkpoint, args.output_folder,
  95. args.model_name, args.model_version, args.force)

No Description

Contributors (2)