You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_instance_segmentation_trainer.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. import zipfile
  7. from functools import partial
  8. from modelscope.hub.snapshot_download import snapshot_download
  9. from modelscope.metainfo import Trainers
  10. from modelscope.models.cv.image_instance_segmentation import \
  11. CascadeMaskRCNNSwinModel
  12. from modelscope.msdatasets import MsDataset
  13. from modelscope.msdatasets.task_datasets import \
  14. ImageInstanceSegmentationCocoDataset
  15. from modelscope.trainers import build_trainer
  16. from modelscope.utils.config import Config, ConfigDict
  17. from modelscope.utils.constant import ModelFile
  18. from modelscope.utils.test_utils import test_level
  19. class TestImageInstanceSegmentationTrainer(unittest.TestCase):
  20. model_id = 'damo/cv_swin-b_image-instance-segmentation_coco'
  21. def setUp(self):
  22. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  23. cache_path = snapshot_download(self.model_id)
  24. config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
  25. cfg = Config.from_file(config_path)
  26. max_epochs = cfg.train.max_epochs
  27. samples_per_gpu = cfg.train.dataloader.batch_size_per_gpu
  28. try:
  29. train_data_cfg = cfg.dataset.train
  30. val_data_cfg = cfg.dataset.val
  31. except Exception:
  32. train_data_cfg = None
  33. val_data_cfg = None
  34. if train_data_cfg is None:
  35. # use default toy data
  36. train_data_cfg = ConfigDict(
  37. name='pets_small',
  38. split='train',
  39. classes=('Cat', 'Dog'),
  40. folder_name='Pets',
  41. test_mode=False)
  42. if val_data_cfg is None:
  43. val_data_cfg = ConfigDict(
  44. name='pets_small',
  45. split='validation',
  46. classes=('Cat', 'Dog'),
  47. folder_name='Pets',
  48. test_mode=True)
  49. self.train_dataset = MsDataset.load(
  50. dataset_name=train_data_cfg.name,
  51. split=train_data_cfg.split,
  52. classes=train_data_cfg.classes,
  53. folder_name=train_data_cfg.folder_name,
  54. test_mode=train_data_cfg.test_mode)
  55. assert self.train_dataset.config_kwargs[
  56. 'classes'] == train_data_cfg.classes
  57. assert next(
  58. iter(self.train_dataset.config_kwargs['split_config'].values()))
  59. self.eval_dataset = MsDataset.load(
  60. dataset_name=val_data_cfg.name,
  61. split=val_data_cfg.split,
  62. classes=val_data_cfg.classes,
  63. folder_name=val_data_cfg.folder_name,
  64. test_mode=val_data_cfg.test_mode)
  65. assert self.eval_dataset.config_kwargs[
  66. 'classes'] == val_data_cfg.classes
  67. assert next(
  68. iter(self.eval_dataset.config_kwargs['split_config'].values()))
  69. from mmcv.parallel import collate
  70. self.collate_fn = partial(collate, samples_per_gpu=samples_per_gpu)
  71. self.max_epochs = max_epochs
  72. self.tmp_dir = tempfile.TemporaryDirectory().name
  73. if not os.path.exists(self.tmp_dir):
  74. os.makedirs(self.tmp_dir)
  75. def tearDown(self):
  76. shutil.rmtree(self.tmp_dir)
  77. super().tearDown()
  78. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  79. def test_trainer(self):
  80. kwargs = dict(
  81. model=self.model_id,
  82. data_collator=self.collate_fn,
  83. train_dataset=self.train_dataset,
  84. eval_dataset=self.eval_dataset,
  85. work_dir=self.tmp_dir)
  86. trainer = build_trainer(
  87. name=Trainers.image_instance_segmentation, default_args=kwargs)
  88. trainer.train()
  89. results_files = os.listdir(self.tmp_dir)
  90. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  91. for i in range(self.max_epochs):
  92. self.assertIn(f'epoch_{i+1}.pth', results_files)
  93. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  94. def test_trainer_with_model_and_args(self):
  95. tmp_dir = tempfile.TemporaryDirectory().name
  96. if not os.path.exists(tmp_dir):
  97. os.makedirs(tmp_dir)
  98. cache_path = snapshot_download(self.model_id)
  99. model = CascadeMaskRCNNSwinModel.from_pretrained(cache_path)
  100. kwargs = dict(
  101. cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
  102. model=model,
  103. data_collator=self.collate_fn,
  104. train_dataset=self.train_dataset,
  105. eval_dataset=self.eval_dataset,
  106. work_dir=self.tmp_dir)
  107. trainer = build_trainer(
  108. name=Trainers.image_instance_segmentation, default_args=kwargs)
  109. trainer.train()
  110. results_files = os.listdir(self.tmp_dir)
  111. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  112. for i in range(self.max_epochs):
  113. self.assertIn(f'epoch_{i+1}.pth', results_files)
  114. if __name__ == '__main__':
  115. unittest.main()