diff --git a/modelscope/models/nlp/palm_for_text_generation.py b/modelscope/models/nlp/palm_for_text_generation.py
index ffba7265..e5799feb 100644
--- a/modelscope/models/nlp/palm_for_text_generation.py
+++ b/modelscope/models/nlp/palm_for_text_generation.py
@@ -7,7 +7,7 @@ from ..builder import MODELS

 __all__ = ['PalmForTextGeneration']

-@MODELS.register_module(Tasks.text_generation, module_name=r'palm')
+@MODELS.register_module(Tasks.text_generation, module_name=r'palm2.0')
 class PalmForTextGeneration(Model):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -18,35 +18,26 @@ class PalmForTextGeneration(Model):
             model_cls (Optional[Any], optional): model loader, if None, use
                 the default loader to load model weights, by default None.
         """
-        from sofa import PalmTokenizer
-
         super().__init__(model_dir, *args, **kwargs)
         self.model_dir = model_dir

-        from sofa.models.palm import PalmForConditionalGeneration, TextGenerator
-        tokenizer = kwargs.pop('tokenizer',
-                               PalmTokenizer.from_pretrained(model_dir))
+        from sofa.models.palm_v2 import PalmForConditionalGeneration, Translator
         model = PalmForConditionalGeneration.from_pretrained(model_dir)
-        self.generator = TextGenerator(model, tokenizer)
+        self.tokenizer = model.tokenizer
+        self.generator = Translator(model)

     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model

         Args:
-            input (Dict[str, Any]): the preprocessed data
+            input (Dict[str, Tensor]): the preprocessed data

         Returns:
-            Dict[str, np.ndarray]: results
+            Dict[str, Tensor]: results
         Example:
             {
-                'predictions': array([1]), # lable 0-negative 1-positive
-                'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
-                'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
+                'predictions': Tensor([[1377, 4959, 2785, 6392...]]), # tokens need to be decoded by the tokenizer
             }
         """
-        encoder_inputs = [
-            input['input_ids'], input['token_type_ids'],
-            input['attention_mask']
-        ]
-        return self.generator(encoder_inputs)
+        return self.generator(**input)
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index c24a7c3e..6e2c791d 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -22,7 +22,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
     Tasks.text_classification: ('bert-sentiment-analysis',
                                 'damo/bert-base-sst2'),
-    Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
+    Tasks.text_generation: ('palm2.0',
+                            'damo/nlp_palm2.0_text-generation_chinese-base'),
     Tasks.image_captioning: ('ofa', None),
     Tasks.image_generation:
     ('person-image-cartoon',
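With this change the tokenizer ships inside the sofa `palm_v2` model instead of being built separately from vocab files, so downstream code reads it off the loaded model. A minimal sketch of the resulting loading flow, assuming the `palm2.0` registrations above are in place (the model id is the new default registered in builder.py):

# Sketch: how a caller picks up the bundled tokenizer after this change.
from modelscope.models import Model

# from_pretrained resolves to PalmForTextGeneration via the 'palm2.0'
# module_name registered in the model diff above.
model = Model.from_pretrained('damo/nlp_palm2.0_text-generation_chinese-base')
tokenizer = model.tokenizer  # bundled with the sofa palm_v2 model
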
diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py
index 8b6bf8a9..881e7ea6 100644
--- a/modelscope/pipelines/nlp/text_generation_pipeline.py
+++ b/modelscope/pipelines/nlp/text_generation_pipeline.py
@@ -10,7 +10,7 @@ from ..builder import PIPELINES

 __all__ = ['TextGenerationPipeline']

-@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm')
+@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm2.0')
 class TextGenerationPipeline(Pipeline):

     def __init__(self,
@@ -23,15 +23,16 @@ class TextGenerationPipeline(Pipeline):
             model (SequenceClassificationModel): a model instance
             preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
         """
-        sc_model = model if isinstance(
+        model = model if isinstance(
             model, PalmForTextGeneration) else Model.from_pretrained(model)
         if preprocessor is None:
             preprocessor = TextGenerationPreprocessor(
-                sc_model.model_dir,
+                model.model_dir,
+                model.tokenizer,
                 first_sequence='sentence',
                 second_sequence=None)
-        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
-        self.tokenizer = preprocessor.tokenizer
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.tokenizer = model.tokenizer

     def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
         """process the prediction results
@@ -42,17 +43,20 @@ class TextGenerationPipeline(Pipeline):
         Returns:
             Dict[str, str]: the prediction results
         """
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
+                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
+                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+        replace_tokens_roberta = ((r' +', ' '), ('<mask>', '<q>'),
+                                  ('<pad>', ''), ('<s>', ''), ('</s>', ''),
+                                  ('<unk>', ' '))

-        vocab_size = len(self.tokenizer.vocab)
         pred_list = inputs['predictions']
         pred_ids = pred_list[0][0].cpu().numpy().tolist()
-        for j in range(len(pred_ids)):
-            if pred_ids[j] >= vocab_size:
-                pred_ids[j] = 100
-        pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
-        pred_string = ''.join(pred).replace(
-            '##',
-            '').split('[SEP]')[0].replace('[CLS]',
-                                          '').replace('[SEP]',
-                                                      '').replace('[UNK]', '')
+        pred_string = self.tokenizer.decode(pred_ids)
+        for _old, _new in replace_tokens_bert:
+            pred_string = pred_string.replace(_old, _new)
+        pred_string = pred_string.strip()
+        for _old, _new in replace_tokens_roberta:
+            pred_string = pred_string.replace(_old, _new)
+        pred_string = pred_string.strip()
         return {'text': pred_string}
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 6a4a25fc..9bcaa87c 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -115,17 +115,15 @@ class SequenceClassificationPreprocessor(Preprocessor):
         return rst


-@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm')
+@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm2.0')
 class TextGenerationPreprocessor(Preprocessor):

-    def __init__(self, model_dir: str, *args, **kwargs):
+    def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
         """preprocess the data using the vocab.txt from the `model_dir` path

         Args:
             model_dir (str): model path
         """
-        from sofa import PalmTokenizer
-
         super().__init__(*args, **kwargs)
         self.model_dir: str = model_dir

@@ -134,7 +132,7 @@ class TextGenerationPreprocessor(Preprocessor):
         self.second_sequence: str = kwargs.pop('second_sequence',
                                                'second_sequence')
         self.sequence_length: int = kwargs.pop('sequence_length', 128)
-        self.tokenizer = PalmTokenizer.from_pretrained(model_dir)
+        self.tokenizer = tokenizer

     @type_assert(object, str)
     def __call__(self, data: str) -> Dict[str, Any]:
@@ -153,7 +151,7 @@ class TextGenerationPreprocessor(Preprocessor):
             new_data = {self.first_sequence: data}

         # preprocess the data for the model input
-        rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
+        rst = {'input_ids': [], 'attention_mask': []}

         max_seq_length = self.sequence_length

@@ -168,7 +166,6 @@ class TextGenerationPreprocessor(Preprocessor):

             rst['input_ids'].append(feature['input_ids'])
             rst['attention_mask'].append(feature['attention_mask'])
-            rst['token_type_ids'].append(feature['token_type_ids'])

         return {k: torch.tensor(v) for k, v in rst.items()}
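One behavioral note on the postprocess cleanup above: `str.replace` matches literally, so the `(r' +', ' ')` pair only removes the exact two-character substring ' +' and does not collapse runs of spaces the way a regex would. A self-contained sketch of the BERT-token pass on a made-up decoded string:

# Standalone illustration of the BERT-token cleanup above; the sample
# string is hypothetical, not real model output.
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]', ''),
                       (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''),
                       ('[CLS]', ''), ('[UNK]', ''))

pred_string = '[CLS] 今 日 天 气 晴 [SEP] [PAD] [PAD]'
for _old, _new in replace_tokens_bert:
    pred_string = pred_string.replace(_old, _new)  # literal match, not regex
pred_string = pred_string.strip()
print(pred_string)  # -> '今 日 天 气 晴'
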
diff --git a/requirements/nlp.txt b/requirements/nlp.txt
index 8de83798..4e146a81 100644
--- a/requirements/nlp.txt
+++ b/requirements/nlp.txt
@@ -1 +1 @@
-https://alinlp.alibaba-inc.com/pypi/sofa-1.0.1.3-py3-none-any.whl
+https://alinlp.alibaba-inc.com/pypi/sofa-1.0.2-py3-none-any.whl
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index dd5616a2..e97352aa 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,7 +1,7 @@
 addict
 datasets
 easydict
-https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.2.dev0-py3-none-any.whl
+https://mindscope.oss-cn-hangzhou.aliyuncs.com/sdklib/maas_hub-0.2.4.dev0-py3-none-any.whl
 numpy
 opencv-python-headless
 Pillow>=6.2.0
diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py
index 39d57ff7..fbdd165f 100644
--- a/tests/pipelines/test_text_generation.py
+++ b/tests/pipelines/test_text_generation.py
@@ -12,43 +12,67 @@ from modelscope.utils.test_utils import test_level

 class TextGenerationTest(unittest.TestCase):
-    model_id = 'damo/nlp_palm_text-generation_chinese'
-    input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'"
-    input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'"
+    model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
+    model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
+    input_zh = """
+    本文总结了十个可穿戴产品的设计原则，而这些原则，同样也是笔者认为是这个行业最吸引人的地方：
+    1.为人们解决重复性问题；2.从人开始，而不是从机器开始；3.要引起注意，但不要刻意；4.提升用户能力，而不是取代
+    """
+    input_en = """
+    The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
+    her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
+    54 , sparked outrage last week when she decided the 86-year-old should not face a string of charges
+    of paedophilia against nine children because he has dementia . Today , newly-released documents
+    revealed damning evidence that abuse was covered up by police and social workers for more than 20 years .
+    And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
+    pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
+    """

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):
-        cache_path = snapshot_download(self.model_id)
-        preprocessor = TextGenerationPreprocessor(
-            cache_path, first_sequence='sentence', second_sequence=None)
-        model = PalmForTextGeneration(
-            cache_path, tokenizer=preprocessor.tokenizer)
-        pipeline1 = TextGenerationPipeline(model, preprocessor)
-        pipeline2 = pipeline(
-            Tasks.text_generation, model=model, preprocessor=preprocessor)
-        print(f'input: {self.input1}\npipeline1: {pipeline1(self.input1)}')
-        print()
-        print(f'input: {self.input2}\npipeline2: {pipeline2(self.input2)}')
+        for model_id, input in ((self.model_id_zh, self.input_zh),
+                                (self.model_id_en, self.input_en)):
+            cache_path = snapshot_download(model_id)
+            model = PalmForTextGeneration(cache_path)
+            preprocessor = TextGenerationPreprocessor(
+                cache_path,
+                model.tokenizer,
+                first_sequence='sentence',
+                second_sequence=None)
+            pipeline1 = TextGenerationPipeline(model, preprocessor)
+            pipeline2 = pipeline(
+                Tasks.text_generation, model=model, preprocessor=preprocessor)
+            print(
+                f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
+            )

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained(self.model_id)
-        preprocessor = TextGenerationPreprocessor(
-            model.model_dir, first_sequence='sentence', second_sequence=None)
-        pipeline_ins = pipeline(
-            task=Tasks.text_generation, model=model, preprocessor=preprocessor)
-        print(pipeline_ins(self.input1))
+        for model_id, input in ((self.model_id_zh, self.input_zh),
+                                (self.model_id_en, self.input_en)):
+            model = Model.from_pretrained(model_id)
+            preprocessor = TextGenerationPreprocessor(
+                model.model_dir,
+                model.tokenizer,
+                first_sequence='sentence',
+                second_sequence=None)
+            pipeline_ins = pipeline(
+                task=Tasks.text_generation,
+                model=model,
+                preprocessor=preprocessor)
+            print(pipeline_ins(input))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_name(self):
-        pipeline_ins = pipeline(
-            task=Tasks.text_generation, model=self.model_id)
-        print(pipeline_ins(self.input2))
+        for model_id, input in ((self.model_id_zh, self.input_zh),
+                                (self.model_id_en, self.input_en)):
+            pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id)
+            print(pipeline_ins(input))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_default_model(self):
         pipeline_ins = pipeline(task=Tasks.text_generation)
-        print(pipeline_ins(self.input2))
+        print(pipeline_ins(self.input_zh))


 if __name__ == '__main__':
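For quick reference, the zero-config path the updated suite exercises, sketched as a standalone call (the default model id comes from the builder.py change above):

# Minimal usage after this patch; with no model argument, the task default
# resolves to damo/nlp_palm2.0_text-generation_chinese-base.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(task=Tasks.text_generation)
print(pipe('本文总结了十个可穿戴产品的设计原则'))  # -> {'text': '...'}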