Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10385225
fix token classification bugs (target branch: master)
@@ -5,7 +5,6 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .modeling_bert import (
-        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BertForMaskedLM,
         BertForMultipleChoice,
         BertForNextSentencePrediction,
@@ -20,21 +19,14 @@ if TYPE_CHECKING:
         load_tf_weights_in_bert,
     )
-    from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
-    from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
-    from .tokenization_bert_fast import BertTokenizerFast
+    from .configuration_bert import BertConfig, BertOnnxConfig
 else:
     _import_structure = {
-        'configuration_bert':
-        ['BERT_PRETRAINED_CONFIG_ARCHIVE_MAP', 'BertConfig', 'BertOnnxConfig'],
-        'tokenization_bert':
-        ['BasicTokenizer', 'BertTokenizer', 'WordpieceTokenizer'],
+        'configuration_bert': ['BertConfig', 'BertOnnxConfig'],
     }
-    _import_structure['tokenization_bert_fast'] = ['BertTokenizerFast']
     _import_structure['modeling_bert'] = [
-        'BERT_PRETRAINED_MODEL_ARCHIVE_LIST',
         'BertForMaskedLM',
         'BertForMultipleChoice',
         'BertForNextSentencePrediction',
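For context, the `_import_structure` dict above feeds modelscope's `LazyImportModule`, which defers importing the heavy `modeling_bert` submodule until one of its symbols is first touched; the hunk only trims entries that no longer exist, so the mechanism is unchanged. A minimal sketch of that lazy-import pattern (illustrative only, not the actual `LazyImportModule` implementation):

```python
# Sketch of a lazy module: exported symbols resolve on first access
# instead of at package import time. Names here are assumptions.
import importlib
from types import ModuleType


class LazyModuleSketch(ModuleType):
    """Resolve exported symbols on first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert the mapping: exported symbol -> submodule defining it.
        self._symbol_to_module = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, symbol):
        submodule = self._symbol_to_module.get(symbol)
        if submodule is None:
            raise AttributeError(f'{self.__name__!r} has no attribute {symbol!r}')
        module = importlib.import_module(f'.{submodule}', self.__name__)
        value = getattr(module, symbol)
        setattr(self, symbol, value)  # cache so __getattr__ is not hit again
        return value
```

This is why a stale name in `_import_structure` (like the removed archive-list constant) only fails at first access, not at import, which makes such bugs easy to miss.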
@@ -1872,19 +1872,18 @@ class BertForTokenClassification(BertPreTrainedModel):
     @add_start_docstrings_to_model_forward(
         BERT_INPUTS_DOCSTRING.format('batch_size, sequence_length'))
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+    def forward(self,
+                input_ids=None,
+                attention_mask=None,
+                token_type_ids=None,
+                position_ids=None,
+                head_mask=None,
+                inputs_embeds=None,
+                labels=None,
+                output_attentions=None,
+                output_hidden_states=None,
+                return_dict=None,
+                **kwargs):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`,
             *optional*):
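The substantive change in this hunk is the added `**kwargs`: `forward` can now tolerate extra batch keys that a preprocessor attaches for postprocessing. A self-contained sketch of the failure mode (the key name `offset_mapping` is an assumed example, not taken from the patch):

```python
# Before: extra keys in the batch dict crash the starred call.
def forward_old(input_ids=None, attention_mask=None, labels=None):
    return input_ids

# After: **kwargs absorbs keys the signature does not declare.
def forward_new(input_ids=None, attention_mask=None, labels=None, **kwargs):
    return input_ids

batch = {
    'input_ids': [101, 102],
    'attention_mask': [1, 1],
    'offset_mapping': [(0, 0), (0, 1)],  # assumed example of an extra key
}
# forward_old(**batch)  # TypeError: unexpected keyword argument 'offset_mapping'
forward_new(**batch)    # runs; 'offset_mapping' lands in kwargs
```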
@@ -176,7 +176,7 @@ class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):
 @MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert)
 @MODELS.register_module(Tasks.token_classification, module_name=Models.bert)
-class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
+class BertForTokenClassification(TokenClassification, BertPreTrainedModel):
     """Bert token classification model.

     Inherited from TokenClassificationBase.
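The rename is the headline fix here: the class performs token classification but was named `BertForSequenceClassification`, shadowing the genuine sequence-classification model wherever both were imported. Registration itself keys on `(task, module_name)`, not the Python class name, which is why the registry kept working while imports resolved to the wrong class. A simplified stand-in (assumed, not the real `MODELS` registry) showing that distinction:

```python
# Minimal registry sketch: the decorator keys on (task, module_name),
# so registration is insensitive to the class's Python name -- but a
# plain `import BertForSequenceClassification` would still pick up
# whichever class last claimed that name.
_REGISTRY = {}


def register_module(task, module_name):
    def decorator(cls):
        _REGISTRY[(task, module_name)] = cls
        return cls
    return decorator


@register_module('token-classification', 'bert')
class BertForTokenClassification:
    pass


assert _REGISTRY[('token-classification', 'bert')] is BertForTokenClassification
```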
@@ -187,7 +187,7 @@ class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
     def __init__(self, config, model_dir):
         if hasattr(config, 'base_model_prefix'):
-            BertForSequenceClassification.base_model_prefix = config.base_model_prefix
+            BertForTokenClassification.base_model_prefix = config.base_model_prefix
         super().__init__(config, model_dir)

     def build_base_model(self):
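Worth noting why `__init__` rewrites `base_model_prefix` at all: in transformers, that class attribute tells `from_pretrained` which attribute holds the backbone, so checkpoint keys saved under a different prefix still map onto the wrapper. A rough illustration (the prefix value is an assumed example):

```python
# Illustrative only: keys in a checkpoint carry the backbone prefix, and
# from_pretrained strips/matches it via base_model_prefix. Overriding the
# prefix per config lets one wrapper load checkpoints saved either way.
checkpoint_key = 'structbert.encoder.layer.0.attention.self.query.weight'
prefix = 'structbert'  # assumed value read from config.base_model_prefix
stripped = checkpoint_key[len(prefix) + 1:]  # key relative to the backbone
print(stripped)  # encoder.layer.0.attention.self.query.weight
```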
@@ -218,3 +218,28 @@ class BertForSequenceClassification(TokenClassification, BertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             **kwargs)
+
+    @classmethod
+    def _instantiate(cls, **kwargs):
+        """Instantiate the model.
+
+        @param kwargs: Input args.
+            model_dir: The model dir used to load the checkpoint and the label information.
+            num_labels: An optional arg to tell the model how many classes to initialize.
+                Method will call utils.parse_label_mapping if num_labels not supplied.
+                If num_labels is not found, the model will use the default setting (2 classes).
+        @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
+        """
+        model_dir = kwargs.get('model_dir')
+        num_labels = kwargs.get('num_labels')
+        if num_labels is None:
+            label2id = parse_label_mapping(model_dir)
+            if label2id is not None and len(label2id) > 0:
+                num_labels = len(label2id)
+        model_args = {} if num_labels is None else {'num_labels': num_labels}
+        return super(BertPreTrainedModel,
+                     BertForTokenClassification).from_pretrained(
+                         pretrained_model_name_or_path=kwargs.get('model_dir'),
+                         model_dir=kwargs.get('model_dir'),
+                         **model_args)
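The new `_instantiate` infers `num_labels` from the label mapping shipped with the checkpoint whenever the caller does not pass it explicitly. A hedged sketch of what `parse_label_mapping` plausibly reads (file names and key layout below are assumptions, not the verified modelscope implementation):

```python
# Sketch: recover a label2id dict from files commonly shipped alongside
# a checkpoint. Assumed layout; not the real parse_label_mapping code.
import json
import os


def parse_label_mapping_sketch(model_dir):
    """Return label2id from model_dir, or None if nothing usable exists."""
    config_path = os.path.join(model_dir, 'config.json')
    if os.path.exists(config_path):
        with open(config_path, encoding='utf-8') as f:
            config = json.load(f)
        if config.get('label2id'):
            return config['label2id']
        if config.get('id2label'):
            # Hugging Face configs store id2label keyed by string ids.
            return {label: int(idx) for idx, label in config['id2label'].items()}
    return None
```

Also note the deliberate `super(BertPreTrainedModel, ...)` call: it skips one level of the MRO so that `transformers.PreTrainedModel.from_pretrained` does the actual loading.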
@@ -40,7 +40,12 @@ class TokenClassificationPipeline(Pipeline):
                 sequence_length=kwargs.pop('sequence_length', 128))
             model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
-        self.id2label = getattr(model, 'id2label')
+        if hasattr(model, 'id2label'):
+            self.id2label = getattr(model, 'id2label')
+        else:
+            model_config = getattr(model, 'config')
+            self.id2label = getattr(model_config, 'id2label')
         assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
                                           'as a parameter or make sure the preprocessor has the attribute.'
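This pipeline change replaces the unconditional `getattr(model, 'id2label')`, which raised `AttributeError` for plain transformers models that only expose the mapping on their config, with a fallback to `model.config.id2label`. The same resolution order as a standalone helper (a restatement for clarity, not code from the patch):

```python
# Restated fallback: prefer a model-level id2label, then the
# transformers-style config attribute; fail early if neither exists.
def resolve_id2label(model):
    id2label = getattr(model, 'id2label', None)
    if id2label is None:
        config = getattr(model, 'config', None)
        id2label = getattr(config, 'id2label', None)
    if id2label is None:
        raise ValueError('Cannot convert id to the original label; pass the '
                         'mapping explicitly or ensure the model or its '
                         'config exposes id2label.')
    return id2label
```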