1. Fix: ws regression test failure.
2. Fix: id2label missing in text_classification_pipeline when the preprocessor is passed in through args.
3. Fix: remove obsolete imports
4. Fix: incomplete modification
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10936431
master^2
| @@ -1,3 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | version https://git-lfs.github.com/spec/v1 | ||||
| oid sha256:3b38bfb5a851d35d5fba4d59eda926557666dbd62c70e3e3b24c22605e7d9c4a | |||||
| size 40771 | |||||
| oid sha256:dc16ad72e753f751360dab82878ec0a31190fb5125632d8f4698f6537fae79cb | |||||
| size 40819 | |||||
| @@ -79,12 +79,9 @@ class TextClassificationPipeline(Pipeline): | |||||
| 'sequence_length': sequence_length, | 'sequence_length': sequence_length, | ||||
| **kwargs | **kwargs | ||||
| }) | }) | ||||
| assert hasattr(self.preprocessor, 'id2label') | |||||
| self.id2label = self.preprocessor.id2label | |||||
| if self.id2label is None: | |||||
| logger.warn( | |||||
| 'The id2label mapping is None, will return original ids.' | |||||
| ) | |||||
| if hasattr(self.preprocessor, 'id2label'): | |||||
| self.id2label = self.preprocessor.id2label | |||||
| def forward(self, inputs: Dict[str, Any], | def forward(self, inputs: Dict[str, Any], | ||||
| **forward_params) -> Dict[str, Any]: | **forward_params) -> Dict[str, Any]: | ||||
| @@ -111,6 +108,9 @@ class TextClassificationPipeline(Pipeline): | |||||
| if self.model.__class__.__name__ == 'OfaForAllTasks': | if self.model.__class__.__name__ == 'OfaForAllTasks': | ||||
| return inputs | return inputs | ||||
| else: | else: | ||||
| if getattr(self, 'id2label', None) is None: | |||||
| logger.warn( | |||||
| 'The id2label mapping is None, will return original ids.') | |||||
| logits = inputs[OutputKeys.LOGITS].cpu().numpy() | logits = inputs[OutputKeys.LOGITS].cpu().numpy() | ||||
| if logits.shape[0] == 1: | if logits.shape[0] == 1: | ||||
| logits = logits[0] | logits = logits[0] | ||||
| @@ -126,7 +126,7 @@ class TextClassificationPipeline(Pipeline): | |||||
| probs = np.take_along_axis(probs, top_indices, axis=-1).tolist() | probs = np.take_along_axis(probs, top_indices, axis=-1).tolist() | ||||
| def map_to_label(id): | def map_to_label(id): | ||||
| if self.id2label is not None: | |||||
| if getattr(self, 'id2label', None) is not None: | |||||
| if id in self.id2label: | if id in self.id2label: | ||||
| return self.id2label[id] | return self.id2label[id] | ||||
| elif str(id) in self.id2label: | elif str(id) in self.id2label: | ||||
| @@ -30,10 +30,6 @@ if TYPE_CHECKING: | |||||
| from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor | from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor | ||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'nlp_base': [ | |||||
| 'NLPTokenizerPreprocessorBase', | |||||
| 'NLPBasePreprocessor', | |||||
| ], | |||||
| 'sentence_piece_preprocessor': ['SentencePiecePreprocessor'], | 'sentence_piece_preprocessor': ['SentencePiecePreprocessor'], | ||||
| 'bert_seq_cls_tokenizer': ['Tokenize'], | 'bert_seq_cls_tokenizer': ['Tokenize'], | ||||
| 'document_segmentation_preprocessor': | 'document_segmentation_preprocessor': | ||||
| @@ -119,6 +119,6 @@ class FaqQuestionAnsweringTransformersPreprocessor(Preprocessor): | |||||
| def batch_encode(self, sentence_list: list, max_length=None): | def batch_encode(self, sentence_list: list, max_length=None): | ||||
| if not max_length: | if not max_length: | ||||
| max_length = self.MAX_LEN | |||||
| max_length = self.max_len | |||||
| return self.tokenizer.batch_encode_plus( | return self.tokenizer.batch_encode_plus( | ||||
| sentence_list, padding=True, max_length=max_length) | sentence_list, padding=True, max_length=max_length) | ||||
| @@ -555,7 +555,7 @@ if __name__ == '__main__': | |||||
| nargs='*', | nargs='*', | ||||
| help='Run specified test suites(test suite files list split by space)') | help='Run specified test suites(test suite files list split by space)') | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| set_test_level(2) | |||||
| set_test_level(args.level) | |||||
| os.environ['REGRESSION_BASELINE'] = '1' | os.environ['REGRESSION_BASELINE'] = '1' | ||||
| logger.info(f'TEST LEVEL: {test_level()}') | logger.info(f'TEST LEVEL: {test_level()}') | ||||
| if not args.disable_profile: | if not args.disable_profile: | ||||