1. Fix incorrect finetune results from the token classification preprocessor
2. Remove the useless attribute from the word segmentation output
3. Fix the error when passing use_fast to the NLP preprocessors
4. Fix a bug in the torch model exporter
5. Fix trainer-related bugs found while writing documentation
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10573269
master
@@ -128,7 +128,7 @@ class TorchModelExporter(Exporter):
args_list = list(args)
else:
args_list = [args]
if isinstance(args_list[-1], dict):
if isinstance(args_list[-1], Mapping):
args_dict = args_list[-1]
args_list = args_list[:-1]
n_nonkeyword = len(args_list)
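Note on the dict -> Mapping change: the trailing-argument check now accepts any object implementing the Mapping protocol, not just plain dict subclasses. A minimal sketch of the difference (FrozenArgs is a made-up class, not part of the codebase):

from collections.abc import Mapping


class FrozenArgs(Mapping):
    """Made-up read-only mapping that is not a dict subclass."""

    def __init__(self, **kwargs):
        self._data = dict(kwargs)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


kwargs = FrozenArgs(attention_mask=None)
print(isinstance(kwargs, dict))     # False -> the old check would skip it
print(isinstance(kwargs, Mapping))  # True  -> the new check picks it up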
@@ -284,9 +284,8 @@ class TorchModelExporter(Exporter):
'Model property dummy_inputs must be set.')
dummy_inputs = collate_fn(dummy_inputs, device)
if isinstance(dummy_inputs, Mapping):
dummy_inputs = self._decide_input_format(model, dummy_inputs)
dummy_inputs_filter = []
for _input in dummy_inputs:
for _input in self._decide_input_format(model, dummy_inputs):
if _input is not None:
dummy_inputs_filter.append(_input)
else:
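For context, _decide_input_format roughly arranges a dict of dummy inputs into the positional order of model.forward, and the loop above then drops the None placeholders; the fix keeps dummy_inputs bound to the original mapping while iterating over the reordered copy. A rough, assumption-laden sketch of that ordering-and-filtering idea:

import inspect

import torch


def order_dummy_inputs(model, dummy_inputs: dict):
    """Assumed behavior of the input-format step: return dummy values in the
    positional order of model.forward, with None for missing parameters."""
    params = inspect.signature(model.forward).parameters
    return [dummy_inputs.get(name) for name in params]


class ToyModel(torch.nn.Module):

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return input_ids


ordered = order_dummy_inputs(ToyModel(),
                             {'input_ids': torch.ones(1, 4, dtype=torch.long)})
filtered = [arg for arg in ordered if arg is not None]  # drop None placeholders
print(len(ordered), len(filtered))  # 3 1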
@@ -491,17 +491,8 @@ TASK_OUTPUTS = {
# word segmentation result for single sample
# {
# "output": "今天 天气 不错 , 适合 出去 游玩"
# "labels": [
# {'word': '今天', 'label': 'PROPN'},
# {'word': '天气', 'label': 'PROPN'},
# {'word': '不错', 'label': 'VERB'},
# {'word': ',', 'label': 'NUM'},
# {'word': '适合', 'label': 'NOUN'},
# {'word': '出去', 'label': 'PART'},
# {'word': '游玩', 'label': 'ADV'},
# ]
# }
Tasks.word_segmentation: [OutputKeys.OUTPUT, OutputKeys.LABELS],
Tasks.word_segmentation: [OutputKeys.OUTPUT],
# TODO @wenmeng.zwm support list of result check
# named entity recognition result for single sample
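With OutputKeys.LABELS removed, a word segmentation pipeline now returns only the segmented string. A usage sketch (the model id is illustrative and not part of this change):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Illustrative model id; any word segmentation model behaves the same way here.
word_seg = pipeline(
    Tasks.word_segmentation,
    model='damo/nlp_structbert_word-segmentation_chinese-base')
print(word_seg('今天天气不错，适合出去游玩'))
# {'output': '今天 天气 不错 ， 适合 出去 游玩'}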
@@ -109,13 +109,13 @@ class TokenClassificationPipeline(Pipeline):
chunk['span'] = text[chunk['start']:chunk['end']]
chunks.append(chunk)
# for cws output
# for cws outputs
if len(chunks) > 0 and chunks[0]['type'] == 'cws':
spans = [
chunk['span'] for chunk in chunks if chunk['span'].strip()
]
seg_result = ' '.join(spans)
outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
outputs = {OutputKeys.OUTPUT: seg_result}
# for ner outputs
else:
@@ -115,15 +115,15 @@ class WordSegmentationPipeline(Pipeline):
chunk['span'] = text[chunk['start']:chunk['end']]
chunks.append(chunk)
# for cws output
# for cws outputs
if len(chunks) > 0 and chunks[0]['type'] == 'cws':
spans = [
chunk['span'] for chunk in chunks if chunk['span'].strip()
]
seg_result = ' '.join(spans)
outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
outputs = {OutputKeys.OUTPUT: seg_result}
# for ner outpus
# for ner output
else:
outputs = {OutputKeys.OUTPUT: chunks}
return outputs
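Both pipelines build the cws result the same way: join the non-empty chunk spans with spaces and return them under OutputKeys.OUTPUT alone. A standalone sketch of that branch (chunks here is a hand-written stand-in for what postprocess collects):

from modelscope.outputs import OutputKeys

# Hand-written stand-in for the chunks collected earlier in postprocess().
chunks = [
    {'type': 'cws', 'span': '今天'},
    {'type': 'cws', 'span': ' '},
    {'type': 'cws', 'span': '天气'},
    {'type': 'cws', 'span': '不错'},
]
spans = [chunk['span'] for chunk in chunks if chunk['span'].strip()]
outputs = {OutputKeys.OUTPUT: ' '.join(spans)}
print(outputs)  # {'output': '今天 天气 不错'}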
@@ -34,6 +34,7 @@ class NLPBasePreprocessor(Preprocessor, ABC):
label=None,
label2id=None,
mode=ModeKeys.INFERENCE,
use_fast=None,
**kwargs):
"""The NLP preprocessor base class.
@@ -45,14 +46,18 @@ class NLPBasePreprocessor(Preprocessor, ABC):
label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping
if this mapping is not supplied.
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode
use_fast: use the fast version of tokenizer
"""
self.model_dir = model_dir
self.first_sequence = first_sequence
self.second_sequence = second_sequence
self.label = label
self.use_fast = kwargs.pop('use_fast', None)
if self.use_fast is None and os.path.isfile(
self.use_fast = use_fast
if self.use_fast is None and model_dir is None:
self.use_fast = False
elif self.use_fast is None and os.path.isfile(
os.path.join(model_dir, 'tokenizer_config.json')):
with open(os.path.join(model_dir, 'tokenizer_config.json'),
'r') as f:
@@ -61,8 +66,8 @@ class NLPBasePreprocessor(Preprocessor, ABC):
self.use_fast = False if self.use_fast is None else self.use_fast
self.label2id = label2id
if self.label2id is None:
self.label2id = parse_label_mapping(self.model_dir)
if self.label2id is None and model_dir is not None:
self.label2id = parse_label_mapping(model_dir)
super().__init__(mode, **kwargs)
@property
@@ -106,6 +111,7 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor):
label: str = 'label',
label2id: dict = None,
mode: str = ModeKeys.INFERENCE,
use_fast: bool = None,
**kwargs):
"""The NLP tokenizer preprocessor base class.
@@ -122,11 +128,12 @@ class NLPTokenizerPreprocessorBase(NLPBasePreprocessor):
- config.json label2id/id2label
- label_mapping.json
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode, the behavior may be different.
use_fast: use the fast version of tokenizer
kwargs: These kwargs will be directly fed into the tokenizer.
"""
super().__init__(model_dir, first_sequence, second_sequence, label,
label2id, mode)
label2id, mode, use_fast, **kwargs)
self.model_dir = model_dir
self.tokenize_kwargs = kwargs
self.tokenizer = self.build_tokenizer(model_dir)
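Since use_fast is now an explicit constructor argument forwarded through super().__init__ to NLPBasePreprocessor, callers can set it directly instead of relying on **kwargs. A small usage sketch (import path and model directory are assumptions for illustration):

from modelscope.preprocessors import TokenClassificationPreprocessor

# '/path/to/model_dir' stands for any local model directory with a tokenizer.
preprocessor = TokenClassificationPreprocessor(
    model_dir='/path/to/model_dir',
    use_fast=True)  # forwarded to the base class instead of being dropped
print(preprocessor.use_fast)  # True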
@@ -2,6 +2,7 @@
from typing import Any, Dict, Tuple, Union
import numpy as np
import torch
from modelscope.metainfo import Preprocessors
@@ -20,9 +21,7 @@ class WordSegmentationBlankSetToLabelPreprocessor(NLPBasePreprocessor):
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.first_sequence: str = kwargs.pop('first_sequence', 'tokens')
self.label = kwargs.pop('label', OutputKeys.LABELS)
def __call__(self, data: str) -> Union[Dict[str, Any], Tuple]:
@@ -80,10 +79,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
'is_split_into_words', False)
if 'label2id' in kwargs:
kwargs.pop('label2id')
self.tokenize_kwargs = kwargs
@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
@type_assert(object, (str, dict))
def __call__(self, data: Union[dict, str]) -> Dict[str, Any]:
"""process the raw input data
Args:
@@ -99,18 +97,24 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
text = None
labels_list = None
if isinstance(data, str):
# for inference inputs without label
text = data
self.tokenize_kwargs['add_special_tokens'] = False
elif isinstance(data, dict):
# for finetune inputs with label
text = data.get(self.first_sequence)
labels_list = data.get(self.label)
if isinstance(text, list):
self.tokenize_kwargs['is_split_into_words'] = True
input_ids = []
label_mask = []
offset_mapping = []
if self.is_split_into_words:
for offset, token in enumerate(list(data)):
subtoken_ids = self.tokenizer.encode(
token, add_special_tokens=False)
token_type_ids = []
if self.is_split_into_words and self._mode == ModeKeys.INFERENCE:
for offset, token in enumerate(list(text)):
subtoken_ids = self.tokenizer.encode(token,
**self.tokenize_kwargs)
if len(subtoken_ids) == 0:
subtoken_ids = [self.tokenizer.unk_token_id]
input_ids.extend(subtoken_ids)
@@ -119,10 +123,9 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
else:
if self.tokenizer.is_fast:
encodings = self.tokenizer(
text,
add_special_tokens=False,
return_offsets_mapping=True,
**self.tokenize_kwargs)
text, return_offsets_mapping=True, **self.tokenize_kwargs)
attention_mask = encodings['attention_mask']
token_type_ids = encodings['token_type_ids']
input_ids = encodings['input_ids']
word_ids = encodings.word_ids()
for i in range(len(word_ids)):
@@ -143,69 +146,80 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
label_mask, offset_mapping = self.get_label_mask_and_offset_mapping(
text)
if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
offset_mapping = offset_mapping[:sum(label_mask)]
if self._mode == ModeKeys.INFERENCE:
if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
offset_mapping = offset_mapping[:sum(label_mask)]
if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]
if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]
if self._mode == ModeKeys.INFERENCE:
input_ids = torch.tensor(input_ids).unsqueeze(0)
attention_mask = torch.tensor(attention_mask).unsqueeze(0)
label_mask = torch.tensor(
label_mask, dtype=torch.bool).unsqueeze(0)
# the token classification
output = {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}
# align the labels with tokenized text
if labels_list is not None:
assert self.label2id is not None
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []
label_enumerate_values = [
k for k, v in sorted(
self.label2id.items(), key=lambda item: item[1])
]
for idx, label in enumerate(label_enumerate_values):
if label.startswith('B-') and label.replace(
'B-', 'I-') in label_enumerate_values:
b_to_i_label.append(
label_enumerate_values.index(
label.replace('B-', 'I-')))
else:
b_to_i_label.append(idx)
# the token classification
output = {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}
else:
output = {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
}
label_row = [self.label2id[lb] for lb in labels_list]
previous_word_idx = None
label_ids = []
for word_idx in word_ids:
if word_idx is None:
label_ids.append(-100)
elif word_idx != previous_word_idx:
label_ids.append(label_row[word_idx])
else:
if self.label_all_tokens:
label_ids.append(b_to_i_label[label_row[word_idx]])
# align the labels with tokenized text
if labels_list is not None:
assert self.label2id is not None
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []
label_enumerate_values = [
k for k, v in sorted(
self.label2id.items(), key=lambda item: item[1])
]
for idx, label in enumerate(label_enumerate_values):
if label.startswith('B-') and label.replace(
'B-', 'I-') in label_enumerate_values:
b_to_i_label.append(
label_enumerate_values.index(
label.replace('B-', 'I-')))
else:
b_to_i_label.append(idx)
label_row = [self.label2id[lb] for lb in labels_list]
previous_word_idx = None
label_ids = []
for word_idx in word_ids:
if word_idx is None:
label_ids.append(-100)
previous_word_idx = word_idx
labels = label_ids
output['labels'] = labels
elif word_idx != previous_word_idx:
label_ids.append(label_row[word_idx])
else:
if self.label_all_tokens:
label_ids.append(b_to_i_label[label_row[word_idx]])
else:
label_ids.append(-100)
previous_word_idx = word_idx
labels = label_ids
output['labels'] = labels
output = {
k: np.array(v) if isinstance(v, list) else v
for k, v in output.items()
}
return output
def get_tokenizer_class(self):
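After this rewrite, __call__ branches on self._mode: at inference a plain string produces batched torch tensors plus offset_mapping, while at finetune time a dict with pre-split tokens and labels produces per-sample numpy arrays including an aligned 'labels' field. A hedged sketch of the two call shapes (paths and label set are illustrative):

from modelscope.preprocessors import TokenClassificationPreprocessor
from modelscope.utils.constant import ModeKeys

# Inference: raw string in, batched tensors out (used by the pipelines above).
infer_preprocessor = TokenClassificationPreprocessor('/path/to/model_dir')
infer_inputs = infer_preprocessor('今天天气不错，适合出去游玩')

# Finetune: dict with pre-split tokens and labels in, numpy arrays out.
train_preprocessor = TokenClassificationPreprocessor(
    '/path/to/model_dir',
    first_sequence='tokens',
    label='labels',
    label2id={'O': 0, 'B-LOC': 1, 'I-LOC': 2},
    mode=ModeKeys.TRAIN)
train_inputs = train_preprocessor({
    'tokens': ['今天', '天气', '不错'],
    'labels': ['O', 'O', 'O'],
})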
@@ -18,7 +18,7 @@ class TextGenerationTrainer(NlpEpochBasedTrainer):
return tokenizer.decode(tokens.tolist(), skip_special_tokens=True)
def evaluation_step(self, data):
model = self.model
model = self.model.module if self._dist else self.model
model.eval()
with torch.no_grad():
@@ -586,14 +586,16 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
preprocessor_mode=ModeKeys.TRAIN,
**model_args,
**self.train_keys,
mode=ModeKeys.TRAIN)
mode=ModeKeys.TRAIN,
use_fast=True)
eval_preprocessor = Preprocessor.from_pretrained(
self.model_dir,
cfg_dict=self.cfg,
preprocessor_mode=ModeKeys.EVAL,
**model_args,
**self.eval_keys,
mode=ModeKeys.EVAL)
mode=ModeKeys.EVAL,
use_fast=True)
return train_preprocessor, eval_preprocessor
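Passing use_fast=True here matters because the label-alignment path in TokenClassificationPreprocessor relies on word_ids() and offset mappings, which only Rust-backed fast tokenizers provide. A short illustration with a plain transformers tokenizer (the model name is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese', use_fast=True)
encodings = tokenizer(['今天', '天气', '不错'],
                      is_split_into_words=True,
                      return_offsets_mapping=True)
print(tokenizer.is_fast)     # True
print(encodings.word_ids())  # e.g. [None, 0, 0, 1, 1, 2, 2, None]
# A slow (pure-Python) tokenizer raises on word_ids() and offset mappings.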
@@ -876,7 +876,7 @@ class EpochBasedTrainer(BaseTrainer):
Subclass and override to inject custom behavior.
"""
model = self.model
model = self.model.module if self._dist else self.model
model.eval()
if is_parallel(model):
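Both trainer fixes unwrap the DistributedDataParallel container before switching to eval mode, since under distributed training self.model is the DDP wrapper and the underlying network lives in .module. A minimal sketch of the pattern:

import torch


def unwrap_for_eval(model, is_dist):
    """Sketch of the fix above: under DDP the real network is model.module."""
    model = model.module if is_dist else model
    model.eval()
    return model


net = torch.nn.Linear(4, 2)
print(unwrap_for_eval(net, is_dist=False).training)  # False -> eval mode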
@@ -21,9 +21,10 @@ class TestModelOutput(unittest.TestCase):
self.assertEqual(outputs['logits'], torch.Tensor([1]))
self.assertEqual(outputs[0], torch.Tensor([1]))
self.assertEqual(outputs.logits, torch.Tensor([1]))
outputs.loss = torch.Tensor([2])
logits, loss = outputs
self.assertEqual(logits, torch.Tensor([1]))
self.assertTrue(loss is None)
self.assertTrue(loss is not None)
if __name__ == '__main__':
@@ -87,7 +87,7 @@ class TestFinetuneTokenClassification(unittest.TestCase):
cfg['dataset'] = {
'train': {
'labels': label_enumerate_values,
'first_sequence': 'first_sequence',
'first_sequence': 'tokens',
'label': 'labels',
}
}