You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_text_classification.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import tempfile
  3. import unittest
  4. import zipfile
  5. from pathlib import Path
  6. from ali_maas_datasets import PyDataset
  7. from maas_lib.fileio import File
  8. from maas_lib.models import Model
  9. from maas_lib.models.nlp import SequenceClassificationModel
  10. from maas_lib.pipelines import SequenceClassificationPipeline, pipeline
  11. from maas_lib.preprocessors import SequenceClassificationPreprocessor
  12. from maas_lib.utils.constant import Tasks
  13. class SequenceClassificationTest(unittest.TestCase):
  14. def predict(self, pipeline_ins: SequenceClassificationPipeline):
  15. from easynlp.appzoo import load_dataset
  16. set = load_dataset('glue', 'sst2')
  17. data = set['test']['sentence'][:3]
  18. results = pipeline_ins(data[0])
  19. print(results)
  20. results = pipeline_ins(data[1])
  21. print(results)
  22. print(data)
  23. def test_run(self):
  24. model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
  25. '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
  26. cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
  27. cache_path = Path(cache_path_str)
  28. if not cache_path.exists():
  29. cache_path.parent.mkdir(parents=True, exist_ok=True)
  30. cache_path.touch(exist_ok=True)
  31. with cache_path.open('wb') as ofile:
  32. ofile.write(File.read(model_url))
  33. with zipfile.ZipFile(cache_path_str, 'r') as zipf:
  34. zipf.extractall(cache_path.parent)
  35. path = r'.cache/easynlp/'
  36. model = SequenceClassificationModel(path)
  37. preprocessor = SequenceClassificationPreprocessor(
  38. path, first_sequence='sentence', second_sequence=None)
  39. pipeline1 = SequenceClassificationPipeline(model, preprocessor)
  40. self.predict(pipeline1)
  41. pipeline2 = pipeline(
  42. Tasks.text_classification, model=model, preprocessor=preprocessor)
  43. print(pipeline2('Hello world!'))
  44. def test_run_modelhub(self):
  45. model = Model.from_pretrained('damo/bert-base-sst2')
  46. preprocessor = SequenceClassificationPreprocessor(
  47. model.model_dir, first_sequence='sentence', second_sequence=None)
  48. pipeline_ins = pipeline(
  49. task=Tasks.text_classification,
  50. model=model,
  51. preprocessor=preprocessor)
  52. self.predict(pipeline_ins)
  53. def test_run_with_dataset(self):
  54. model = Model.from_pretrained('damo/bert-base-sst2')
  55. preprocessor = SequenceClassificationPreprocessor(
  56. model.model_dir, first_sequence='sentence', second_sequence=None)
  57. text_classification = pipeline(
  58. Tasks.text_classification, model=model, preprocessor=preprocessor)
  59. # loaded from huggingface dataset
  60. # TODO: add load_from parameter (an enum) LOAD_FROM.hugging_face
  61. # TODO: rename parameter as dataset_name and subset_name
  62. dataset = PyDataset.load('glue', name='sst2', target='sentence')
  63. result = text_classification(dataset)
  64. for i, r in enumerate(result):
  65. if i > 10:
  66. break
  67. print(r)
  68. if __name__ == '__main__':
  69. unittest.main()

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展