You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_deberta_tasks.py 2.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. import torch
  4. from modelscope.hub.snapshot_download import snapshot_download
  5. from modelscope.models import Model
  6. from modelscope.models.nlp import DebertaV2ForMaskedLM
  7. from modelscope.models.nlp.deberta_v2 import (DebertaV2Tokenizer,
  8. DebertaV2TokenizerFast)
  9. from modelscope.pipelines import pipeline
  10. from modelscope.pipelines.nlp import FillMaskPipeline
  11. from modelscope.preprocessors import FillMaskPreprocessor
  12. from modelscope.utils.constant import Tasks
  13. from modelscope.utils.test_utils import test_level
  14. class DeBERTaV2TaskTest(unittest.TestCase):
  15. model_id_deberta = 'damo/nlp_debertav2_fill-mask_chinese-lite'
  16. ori_text = '你师父差得动你,你师父可差不动我。'
  17. test_input = '你师父差得动你,你师父可[MASK]不动我。'
  18. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  19. def test_run_by_direct_model_download(self):
  20. model_dir = snapshot_download(self.model_id_deberta)
  21. preprocessor = FillMaskPreprocessor(
  22. model_dir, first_sequence='sentence', second_sequence=None)
  23. model = DebertaV2ForMaskedLM.from_pretrained(model_dir)
  24. pipeline1 = FillMaskPipeline(model, preprocessor)
  25. pipeline2 = pipeline(
  26. Tasks.fill_mask, model=model, preprocessor=preprocessor)
  27. ori_text = self.ori_text
  28. test_input = self.test_input
  29. print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: '
  30. f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n')
  31. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  32. def test_run_with_model_from_modelhub(self):
  33. # sbert
  34. print(self.model_id_deberta)
  35. model = Model.from_pretrained(self.model_id_deberta)
  36. preprocessor = FillMaskPreprocessor(
  37. model.model_dir, first_sequence='sentence', second_sequence=None)
  38. pipeline_ins = pipeline(
  39. task=Tasks.fill_mask, model=model, preprocessor=preprocessor)
  40. print(
  41. f'\nori_text: {self.ori_text}\ninput: {self.test_input}\npipeline: '
  42. f'{pipeline_ins(self.test_input)}\n')
  43. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  44. def test_run_with_model_name(self):
  45. pipeline_ins = pipeline(
  46. task=Tasks.fill_mask, model=self.model_id_deberta)
  47. ori_text = self.ori_text
  48. test_input = self.test_input
  49. print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
  50. f'{pipeline_ins(test_input)}\n')
  51. if __name__ == '__main__':
  52. unittest.main()