Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10385225
fix token classification bugs (target branch: master)
@@ -5,7 +5,6 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .modeling_bert import (
-        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BertForMaskedLM,
         BertForMultipleChoice,
         BertForNextSentencePrediction,
@@ -20,21 +19,14 @@ if TYPE_CHECKING:
         load_tf_weights_in_bert,
     )
-    from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
-    from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
-    from .tokenization_bert_fast import BertTokenizerFast
+    from .configuration_bert import BertConfig, BertOnnxConfig
 else:
     _import_structure = {
-        'configuration_bert':
-        ['BERT_PRETRAINED_CONFIG_ARCHIVE_MAP', 'BertConfig', 'BertOnnxConfig'],
-        'tokenization_bert':
-        ['BasicTokenizer', 'BertTokenizer', 'WordpieceTokenizer'],
+        'configuration_bert': ['BertConfig', 'BertOnnxConfig'],
     }
-    _import_structure['tokenization_bert_fast'] = ['BertTokenizerFast']
     _import_structure['modeling_bert'] = [
-        'BERT_PRETRAINED_MODEL_ARCHIVE_LIST',
         'BertForMaskedLM',
         'BertForMultipleChoice',
         'BertForNextSentencePrediction',
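Note: `_import_structure` feeds the lazy-import machinery referenced in the hunk header, so dropping an entry here removes that name from the module's public surface. A minimal sketch of the pattern this `__init__.py` presumably follows, assuming modelscope's `LazyImportModule` mirrors transformers' `_LazyModule` module-replacement idiom (the exact constructor arguments are an assumption):

import sys
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .configuration_bert import BertConfig, BertOnnxConfig
else:
    _import_structure = {
        'configuration_bert': ['BertConfig', 'BertOnnxConfig'],
    }
    # Replace this module in sys.modules with a lazy proxy: a submodule is
    # only imported when one of its listed attributes is first accessed.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
    )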
@@ -1872,19 +1872,18 @@ class BertForTokenClassification(BertPreTrainedModel):
     @add_start_docstrings_to_model_forward(
         BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length'))
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                position_ids=None,
+                head_mask=None,
+                inputs_embeds=None,
+                labels=None,
+                output_attentions=None,
+                output_hidden_states=None,
+                return_dict=None,
+                **kwargs):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`,
             *optional*):
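Note: the new signature differs only by the trailing `**kwargs`, presumably so a whole feature dict can be forwarded even when it carries keys this head does not consume (the wrapper class later in this review passes `**kwargs` straight through). A runnable sketch of the failure mode the catch-all avoids; `forward_strict`, `forward_tolerant`, and the `offset_mapping` key are illustrative names, not the actual code:

def forward_strict(input_ids=None, attention_mask=None):
    return input_ids

def forward_tolerant(input_ids=None, attention_mask=None, **kwargs):
    # Extra keys (e.g. an 'offset_mapping' emitted by a tokenizer) land in
    # **kwargs and are simply ignored by the head.
    return input_ids

features = {'input_ids': [[101, 102]], 'offset_mapping': [[(0, 2)]]}
forward_tolerant(**features)       # fine: extra key is absorbed by **kwargs
try:
    forward_strict(**features)     # raises: unexpected keyword argument
except TypeError as err:
    print(err)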
@@ -176,7 +176,7 @@ class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):
 @MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert)
 @MODELS.register_module(Tasks.token_classification, module_name=Models.bert)
-class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
+class BertForTokenClassification(TokenClassification, BertPreTrainedModel):
     """Bert token classification model.

     Inherited from TokenClassificationBase.
@@ -187,7 +187,7 @@ class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
     def __init__(self, config, model_dir):
         if hasattr(config, 'base_model_prefix'):
-            BertForSequenceClassification.base_model_prefix = config.base_model_prefix
+            BertForTokenClassification.base_model_prefix = config.base_model_prefix
         super().__init__(config, model_dir)

     def build_base_model(self):
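Note: `base_model_prefix` is a class attribute that transformers' loading code consults when reconciling checkpoint keys with model parameters, so after the rename it must be written onto the live `BertForTokenClassification` class rather than the stale `BertForSequenceClassification` name. A simplified sketch of the prefix-stripping idea (not the actual transformers implementation):

# Checkpoint keys saved as 'bert.xxx' on a headed model must map onto the
# bare base model's parameter names when the prefix does not match.
def strip_base_prefix(state_dict, base_model_prefix):
    prefix = base_model_prefix + '.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

ckpt = {'bert.embeddings.word_embeddings.weight': '...', 'classifier.weight': '...'}
print(strip_base_prefix(ckpt, 'bert'))
# {'embeddings.word_embeddings.weight': '...', 'classifier.weight': '...'}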
@@ -218,3 +218,28 @@ class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             **kwargs)
+
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """Instantiate the model.
+
+        @param kwargs: Input args.
+            model_dir: The model dir used to load the checkpoint and the label information.
+            num_labels: An optional arg to tell the model how many classes to initialize.
+                Method will call utils.parse_label_mapping if num_labels is not supplied.
+                If num_labels is not found, the model will use the default setting (2 classes).
+        @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
+        """
+        model_dir = kwargs.get('model_dir')
+        num_labels = kwargs.get('num_labels')
+        if num_labels is None:
+            label2id = parse_label_mapping(model_dir)
+            if label2id is not None and len(label2id) > 0:
+                num_labels = len(label2id)
+        model_args = {} if num_labels is None else {'num_labels': num_labels}
+        return super(BertPreTrainedModel,
+                     BertForTokenClassification).from_pretrained(
+                         pretrained_model_name_or_path=model_dir,
+                         model_dir=model_dir,
+                         **model_args)
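Note: `_instantiate` infers `num_labels` from the label mapping stored alongside the checkpoint when the caller does not supply it, then defers loading; the `super(BertPreTrainedModel, BertForTokenClassification)` form deliberately skips past `BertPreTrainedModel` in the MRO so the call lands on transformers' `PreTrainedModel.from_pretrained`, as the docstring says. A hedged usage sketch; the directory path is a placeholder, and calling `_instantiate` directly rather than via modelscope's model-loading entry point is for illustration only:

# Illustrative only: '/path/to/token-cls-model' stands in for a directory
# holding a checkpoint plus label-mapping files parse_label_mapping can read.
model = BertForTokenClassification._instantiate(
    model_dir='/path/to/token-cls-model')   # num_labels inferred from label2id

model = BertForTokenClassification._instantiate(
    model_dir='/path/to/token-cls-model',
    num_labels=9)                            # explicit count skips the lookup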
@@ -40,7 +40,12 @@ class TokenClassificationPipeline(Pipeline):
                 sequence_length=kwargs.pop('sequence_length', 128))
             model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
-        self.id2label = getattr(model, 'id2label')
+        if hasattr(model, 'id2label'):
+            self.id2label = getattr(model, 'id2label')
+        else:
+            model_config = getattr(model, 'config')
+            self.id2label = getattr(model_config, 'id2label')
         assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
             'as a parameter or make sure the preprocessor has the attribute.'
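Note: the fallback covers both conventions in play here: modelscope heads that expose `id2label` directly on the model, and vanilla transformers models that keep it on `model.config`. The same resolution order as a standalone sketch (`resolve_id2label` is a hypothetical helper name; it uses `getattr` defaults so a missing mapping yields `None` for the subsequent assert instead of raising):

def resolve_id2label(model):
    # Prefer a mapping set directly on the model (modelscope-style heads) ...
    if hasattr(model, 'id2label'):
        return model.id2label
    # ... otherwise fall back to the transformers config attribute.
    config = getattr(model, 'config', None)
    return getattr(config, 'id2label', None)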