You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_denoise_trainer.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. from modelscope.hub.snapshot_download import snapshot_download
  7. from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
  8. from modelscope.msdatasets.image_denoise_data import PairedImageDataset
  9. from modelscope.trainers import build_trainer
  10. from modelscope.utils.config import Config
  11. from modelscope.utils.constant import ModelFile
  12. from modelscope.utils.logger import get_logger
  13. from modelscope.utils.test_utils import test_level
# Module-level logger obtained from modelscope's ``get_logger`` helper.
# NOTE(review): not referenced anywhere in the visible part of this file.
logger = get_logger()
  15. class ImageDenoiseTrainerTest(unittest.TestCase):
  16. def setUp(self):
  17. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  18. self.tmp_dir = tempfile.TemporaryDirectory().name
  19. if not os.path.exists(self.tmp_dir):
  20. os.makedirs(self.tmp_dir)
  21. self.model_id = 'damo/cv_nafnet_image-denoise_sidd'
  22. self.cache_path = snapshot_download(self.model_id)
  23. self.config = Config.from_file(
  24. os.path.join(self.cache_path, ModelFile.CONFIGURATION))
  25. self.dataset_train = PairedImageDataset(
  26. self.config.dataset, self.cache_path, is_train=True)
  27. self.dataset_val = PairedImageDataset(
  28. self.config.dataset, self.cache_path, is_train=False)
  29. def tearDown(self):
  30. shutil.rmtree(self.tmp_dir, ignore_errors=True)
  31. super().tearDown()
  32. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  33. def test_trainer(self):
  34. kwargs = dict(
  35. model=self.model_id,
  36. train_dataset=self.dataset_train,
  37. eval_dataset=self.dataset_val,
  38. work_dir=self.tmp_dir)
  39. trainer = build_trainer(default_args=kwargs)
  40. trainer.train()
  41. results_files = os.listdir(self.tmp_dir)
  42. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  43. for i in range(2):
  44. self.assertIn(f'epoch_{i+1}.pth', results_files)
  45. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  46. def test_trainer_with_model_and_args(self):
  47. model = NAFNetForImageDenoise.from_pretrained(self.cache_path)
  48. kwargs = dict(
  49. cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
  50. model=model,
  51. train_dataset=self.dataset_train,
  52. eval_dataset=self.dataset_val,
  53. max_epochs=2,
  54. work_dir=self.tmp_dir)
  55. trainer = build_trainer(default_args=kwargs)
  56. trainer.train()
  57. results_files = os.listdir(self.tmp_dir)
  58. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  59. for i in range(2):
  60. self.assertIn(f'epoch_{i+1}.pth', results_files)
  61. if __name__ == '__main__':
  62. unittest.main()