You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_text_classification.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import shutil
  3. import unittest
  4. import zipfile
  5. from pathlib import Path
  6. from maas_lib.fileio import File
  7. from maas_lib.models import Model
  8. from maas_lib.models.nlp import BertForSequenceClassification
  9. from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util
  10. from maas_lib.preprocessors import SequenceClassificationPreprocessor
  11. from maas_lib.pydatasets import PyDataset
  12. from maas_lib.utils.constant import Tasks
  13. class SequenceClassificationTest(unittest.TestCase):
  14. def setUp(self) -> None:
  15. self.model_id = 'damo/bert-base-sst2'
  16. # switch to False if downloading everytime is not desired
  17. purge_cache = True
  18. if purge_cache:
  19. shutil.rmtree(
  20. util.get_model_cache_dir(self.model_id), ignore_errors=True)
  21. def predict(self, pipeline_ins: SequenceClassificationPipeline):
  22. from easynlp.appzoo import load_dataset
  23. set = load_dataset('glue', 'sst2')
  24. data = set['test']['sentence'][:3]
  25. results = pipeline_ins(data[0])
  26. print(results)
  27. results = pipeline_ins(data[1])
  28. print(results)
  29. print(data)
  30. def printDataset(self, dataset: PyDataset):
  31. for i, r in enumerate(dataset):
  32. if i > 10:
  33. break
  34. print(r)
  35. def test_run(self):
  36. model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
  37. '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
  38. cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
  39. cache_path = Path(cache_path_str)
  40. if not cache_path.exists():
  41. cache_path.parent.mkdir(parents=True, exist_ok=True)
  42. cache_path.touch(exist_ok=True)
  43. with cache_path.open('wb') as ofile:
  44. ofile.write(File.read(model_url))
  45. with zipfile.ZipFile(cache_path_str, 'r') as zipf:
  46. zipf.extractall(cache_path.parent)
  47. path = r'.cache/easynlp/'
  48. model = BertForSequenceClassification(path)
  49. preprocessor = SequenceClassificationPreprocessor(
  50. path, first_sequence='sentence', second_sequence=None)
  51. pipeline1 = SequenceClassificationPipeline(model, preprocessor)
  52. self.predict(pipeline1)
  53. pipeline2 = pipeline(
  54. Tasks.text_classification, model=model, preprocessor=preprocessor)
  55. print(pipeline2('Hello world!'))
  56. def test_run_with_model_from_modelhub(self):
  57. model = Model.from_pretrained(self.model_id)
  58. preprocessor = SequenceClassificationPreprocessor(
  59. model.model_dir, first_sequence='sentence', second_sequence=None)
  60. pipeline_ins = pipeline(
  61. task=Tasks.text_classification,
  62. model=model,
  63. preprocessor=preprocessor)
  64. self.predict(pipeline_ins)
  65. def test_run_with_model_name(self):
  66. text_classification = pipeline(
  67. task=Tasks.text_classification, model=self.model_id)
  68. result = text_classification(
  69. PyDataset.load('glue', name='sst2', target='sentence'))
  70. self.printDataset(result)
  71. def test_run_with_dataset(self):
  72. model = Model.from_pretrained(self.model_id)
  73. preprocessor = SequenceClassificationPreprocessor(
  74. model.model_dir, first_sequence='sentence', second_sequence=None)
  75. text_classification = pipeline(
  76. Tasks.text_classification, model=model, preprocessor=preprocessor)
  77. # loaded from huggingface dataset
  78. # TODO: add load_from parameter (an enum) LOAD_FROM.hugging_face
  79. # TODO: rename parameter as dataset_name and subset_name
  80. dataset = PyDataset.load('glue', name='sst2', target='sentence')
  81. result = text_classification(dataset)
  82. self.printDataset(result)
  83. if __name__ == '__main__':
  84. unittest.main()

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展