You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_ms_dataset.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import unittest
  2. from modelscope.models import Model
  3. from modelscope.msdatasets import MsDataset
  4. from modelscope.preprocessors import SequenceClassificationPreprocessor
  5. from modelscope.preprocessors.base import Preprocessor
  6. from modelscope.utils.test_utils import require_tf, require_torch, test_level
  7. class ImgPreprocessor(Preprocessor):
  8. def __init__(self, *args, **kwargs):
  9. super().__init__(*args, **kwargs)
  10. self.path_field = kwargs.pop('image_path', 'image_path')
  11. self.width = kwargs.pop('width', 'width')
  12. self.height = kwargs.pop('height', 'width')
  13. def __call__(self, data):
  14. import cv2
  15. image_path = data.get(self.path_field)
  16. if not image_path:
  17. return None
  18. img = cv2.imread(image_path)
  19. return {
  20. 'image':
  21. cv2.resize(img,
  22. (data.get(self.height, 128), data.get(self.width, 128)))
  23. }
  24. class MsDatasetTest(unittest.TestCase):
  25. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  26. def test_ms_csv_basic(self):
  27. ms_ds_train = MsDataset.load(
  28. 'afqmc_small', namespace='userxiaoming', split='train')
  29. print(next(iter(ms_ds_train)))
  30. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  31. def test_ds_basic(self):
  32. ms_ds_full = MsDataset.load(
  33. 'xcopa', subset_name='translation-et', namespace='damotest')
  34. ms_ds = MsDataset.load(
  35. 'xcopa',
  36. subset_name='translation-et',
  37. namespace='damotest',
  38. split='test')
  39. print(next(iter(ms_ds_full['test'])))
  40. print(next(iter(ms_ds)))
  41. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  42. @require_torch
  43. def test_to_torch_dataset_text(self):
  44. model_id = 'damo/bert-base-sst2'
  45. nlp_model = Model.from_pretrained(model_id)
  46. preprocessor = SequenceClassificationPreprocessor(
  47. nlp_model.model_dir,
  48. first_sequence='premise',
  49. second_sequence=None)
  50. ms_ds_train = MsDataset.load(
  51. 'xcopa',
  52. subset_name='translation-et',
  53. namespace='damotest',
  54. split='test')
  55. pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor)
  56. import torch
  57. dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
  58. print(next(iter(dataloader)))
  59. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  60. @require_tf
  61. def test_to_tf_dataset_text(self):
  62. import tensorflow as tf
  63. tf.compat.v1.enable_eager_execution()
  64. model_id = 'damo/bert-base-sst2'
  65. nlp_model = Model.from_pretrained(model_id)
  66. preprocessor = SequenceClassificationPreprocessor(
  67. nlp_model.model_dir,
  68. first_sequence='premise',
  69. second_sequence=None)
  70. ms_ds_train = MsDataset.load(
  71. 'xcopa',
  72. subset_name='translation-et',
  73. namespace='damotest',
  74. split='test')
  75. tf_dataset = ms_ds_train.to_tf_dataset(
  76. batch_size=5,
  77. shuffle=True,
  78. preprocessors=preprocessor,
  79. drop_remainder=True)
  80. print(next(iter(tf_dataset)))
  81. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  82. @require_torch
  83. def test_to_torch_dataset_img(self):
  84. ms_image_train = MsDataset.load(
  85. 'fixtures_image_utils', namespace='damotest', split='test')
  86. pt_dataset = ms_image_train.to_torch_dataset(
  87. preprocessors=ImgPreprocessor(image_path='file'))
  88. import torch
  89. dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
  90. print(next(iter(dataloader)))
  91. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  92. @require_tf
  93. def test_to_tf_dataset_img(self):
  94. import tensorflow as tf
  95. tf.compat.v1.enable_eager_execution()
  96. ms_image_train = MsDataset.load(
  97. 'fixtures_image_utils', namespace='damotest', split='test')
  98. tf_dataset = ms_image_train.to_tf_dataset(
  99. batch_size=5,
  100. shuffle=True,
  101. preprocessors=ImgPreprocessor(image_path='file'),
  102. drop_remainder=True,
  103. )
  104. print(next(iter(tf_dataset)))
  105. if __name__ == '__main__':
  106. unittest.main()