|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from typing import Any, Dict
-
- import json
- import numpy as np
-
- from modelscope.metainfo import Models
- from modelscope.models import TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.utils.constant import Tasks
-
- __all__ = ['BertForSequenceClassification']
-
-
- @MODELS.register_module(Tasks.text_classification, module_name=Models.bert)
- class BertForSequenceClassification(TorchModel):
-
- def __init__(self, model_dir: str, *args, **kwargs):
- # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs)
- # Predictor.__init__(self, *args, **kwargs)
- """initialize the sequence classification model from the `model_dir` path.
-
- Args:
- model_dir (str): the model path.
- """
-
- super().__init__(model_dir, *args, **kwargs)
- import torch
- from easynlp.appzoo import SequenceClassification
- from easynlp.core.predictor import get_model_predictor
- self.model = get_model_predictor(
- model_dir=self.model_dir,
- model_cls=SequenceClassification,
- input_keys=[('input_ids', torch.LongTensor),
- ('attention_mask', torch.LongTensor),
- ('token_type_ids', torch.LongTensor)],
- output_keys=['predictions', 'probabilities', 'logits'])
-
- self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
- with open(self.label_path) as f:
- self.label_mapping = json.load(f)
- self.id2label = {idx: name for name, idx in self.label_mapping.items()}
-
- def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
- """return the result by the model
-
- Args:
- input (Dict[str, Any]): the preprocessed data
-
- Returns:
- Dict[str, np.ndarray]: results
- Example:
- {
- 'predictions': array([1]), # lable 0-negative 1-positive
- 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
- 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
- }
- """
- return self.model.predict(input)
-
- def postprocess(self, inputs: Dict[str, np.ndarray],
- **kwargs) -> Dict[str, np.ndarray]:
- # N x num_classes
- probs = inputs['probabilities']
- result = {
- 'probs': probs,
- }
-
- return result
|