MaaS: add FAQ question-answering model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9797053
Branch: master
@@ -49,8 +49,8 @@ def handle_http_response(response, logger, cookies, model_id):
     except HTTPError:
         if cookies is None:  # code in [403] and
             logger.error(
-                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be private. \
-                Please login first.')
+                f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
+                private. Please login first.')
         raise
@@ -138,6 +138,7 @@ class Pipelines(object):
     dialog_state_tracking = 'dialog-state-tracking'
     zero_shot_classification = 'zero-shot-classification'
     text_error_correction = 'text-error-correction'
+    faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'

     # audio tasks
@@ -220,6 +221,7 @@ class Preprocessors(object):
     text_error_correction = 'text-error-correction'
     word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
     fill_mask = 'fill-mask'
+    faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
     conversational_text_to_sql = 'conversational-text-to-sql'

     # audio preprocessor
@@ -6,7 +6,8 @@ import torch

 def point_form(boxes: torch.Tensor) -> torch.Tensor:
-    """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form ground truth data.
+    """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form \
+    ground truth data.

     Args:
         boxes: center-size default boxes from priorbox layers.
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
     from .task_models.task_model import SingleBackboneTaskModelBase
     from .bart_for_text_error_correction import BartForTextErrorCorrection
     from .gpt3 import GPT3ForTextGeneration
+    from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering
 else:
     _import_structure = {
@@ -44,6 +45,7 @@ else:
         'task_model': ['SingleBackboneTaskModelBase'],
         'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
         'gpt3': ['GPT3ForTextGeneration'],
+        'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering']
     }

     import sys
@@ -0,0 +1,249 @@ (new file)
import math
import os
from collections import namedtuple
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.models.nlp.structbert import SbertConfig, SbertModel
from modelscope.models.nlp.task_models.task_model import BaseTaskModel
from modelscope.utils.config import Config, ConfigFields
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['SbertForFaqQuestionAnswering']


class SbertForFaqQuestionAnsweringBase(BaseTaskModel):
    """Base class for FAQ models."""

    def __init__(self, model_dir, *args, **kwargs):
        super(SbertForFaqQuestionAnsweringBase,
              self).__init__(model_dir, *args, **kwargs)
        backbone_cfg = SbertConfig.from_pretrained(model_dir)
        self.bert = SbertModel(backbone_cfg)

        model_config = Config.from_file(
            os.path.join(model_dir,
                         ModelFile.CONFIGURATION)).get(ConfigFields.model, {})
        metric = model_config.get('metric', 'cosine')
        pooling_method = model_config.get('pooling', 'avg')

        Arg = namedtuple('args', [
            'metrics', 'proj_hidden_size', 'hidden_size', 'dropout', 'pooling'
        ])
        args = Arg(
            metrics=metric,
            proj_hidden_size=self.bert.config.hidden_size,
            hidden_size=self.bert.config.hidden_size,
            dropout=0.0,
            pooling=pooling_method)
        self.metrics_layer = MetricsLayer(args)
        self.pooling = PoolingLayer(args)

    def _get_onehot_labels(self, labels, support_size, num_cls):
        # [support_size] label ids -> [support_size, num_cls] one-hot matrix.
        labels_ = labels.view(support_size, 1)
        target_oh = torch.zeros(support_size, num_cls).to(labels)
        target_oh.scatter_(dim=1, index=labels_, value=1)
        return target_oh.view(support_size, num_cls).float()

    def forward_sentence_embedding(self, inputs: Dict[str, Tensor]):
        input_ids = inputs['input_ids']
        input_mask = inputs['attention_mask']
        if not isinstance(input_ids, Tensor):
            input_ids = torch.IntTensor(input_ids)
        if not isinstance(input_mask, Tensor):
            input_mask = torch.IntTensor(input_mask)
        rst = self.bert(input_ids, input_mask)
        last_hidden_states = rst.last_hidden_state
        if len(input_mask.shape) == 2:
            input_mask = input_mask.unsqueeze(-1)
        pooled_representation = self.pooling(last_hidden_states, input_mask)
        return pooled_representation


@MODELS.register_module(
    Tasks.faq_question_answering, module_name=Models.structbert)
class SbertForFaqQuestionAnswering(SbertForFaqQuestionAnsweringBase):
    _backbone_prefix = ''

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        assert not self.training
        query = input['query']
        support = input['support']
        if isinstance(query, list):
            query = torch.stack(query)
        if isinstance(support, list):
            support = torch.stack(support)
        n_query = query.shape[0]
        n_support = support.shape[0]
        query_mask = torch.ne(query, 0).view([n_query, -1])
        support_mask = torch.ne(support, 0).view([n_support, -1])

        support_labels = input['support_labels']
        num_cls = torch.max(support_labels) + 1
        onehot_labels = self._get_onehot_labels(support_labels, n_support,
                                                num_cls)

        # Embed queries and support sentences in a single batch.
        input_ids = torch.cat([query, support])
        input_mask = torch.cat([query_mask, support_mask], dim=0)
        pooled_representation = self.forward_sentence_embedding({
            'input_ids': input_ids,
            'attention_mask': input_mask
        })
        z_query = pooled_representation[:n_query]
        z_support = pooled_representation[n_query:]

        # Class prototypes: mean of the support embeddings of each label.
        cls_n_support = torch.sum(onehot_labels, dim=-2) + 1e-5
        protos = torch.matmul(onehot_labels.transpose(0, 1),
                              z_support) / cls_n_support.unsqueeze(-1)
        scores = self.metrics_layer(z_query, protos).view([n_query, num_cls])
        if self.metrics_layer.name == 'relation':
            scores = torch.sigmoid(scores)
        return {'scores': scores}


activations = {
    'relu': F.relu,
    'tanh': torch.tanh,
    'linear': lambda x: x,
}

activation_coeffs = {
    'relu': math.sqrt(2),
    'tanh': 5 / 3,
    'linear': 1.,
}


class LinearProjection(nn.Module):

    def __init__(self,
                 in_features,
                 out_features,
                 activation='linear',
                 bias=True):
        super().__init__()
        self.activation = activations[activation]
        activation_coeff = activation_coeffs[activation]
        linear = nn.Linear(in_features, out_features, bias=bias)
        nn.init.normal_(
            linear.weight, std=math.sqrt(1. / in_features) * activation_coeff)
        if bias:
            nn.init.zeros_(linear.bias)
        self.model = nn.utils.weight_norm(linear)

    def forward(self, x):
        return self.activation(self.model(x))


class RelationModule(nn.Module):

    def __init__(self, args):
        super(RelationModule, self).__init__()
        input_size = args.proj_hidden_size * 4
        self.prediction = torch.nn.Sequential(
            LinearProjection(
                input_size, args.proj_hidden_size * 4, activation='relu'),
            nn.Dropout(args.dropout),
            LinearProjection(args.proj_hidden_size * 4, 1))

    def forward(self, query, protos):
        n_cls = protos.shape[0]
        n_query = query.shape[0]
        protos = protos.unsqueeze(0).repeat(n_query, 1, 1)
        query = query.unsqueeze(1).repeat(1, n_cls, 1)
        input_feat = torch.cat(
            [query, protos, (protos - query).abs(), query * protos], dim=-1)
        dists = self.prediction(input_feat)  # [n_query, n_cls, 1]
        return dists.squeeze(-1)


class MetricsLayer(nn.Module):

    def __init__(self, args):
        super(MetricsLayer, self).__init__()
        self.args = args
        assert args.metrics in ('relation', 'cosine')
        if args.metrics == 'relation':
            self.relation_net = RelationModule(args)

    @property
    def name(self):
        return self.args.metrics

    def forward(self, query, protos):
        """query: [n_query, dim]
        protos: [n_cls, dim]
        """
        if self.args.metrics == 'cosine':
            supervised_dists = self.cosine_similarity(query, protos)
            if self.training:
                supervised_dists *= 5
        elif self.args.metrics in ('relation', ):
            supervised_dists = self.relation_net(query, protos)
        else:
            raise NotImplementedError
        return supervised_dists

    def cosine_similarity(self, x, y):
        # x: [n_query, dim]
        # y: [n_cls, dim]
        n_query = x.shape[0]
        n_cls = y.shape[0]
        dim = x.shape[-1]
        x = x.unsqueeze(1).expand([n_query, n_cls, dim])
        y = y.unsqueeze(0).expand([n_query, n_cls, dim])
        return F.cosine_similarity(x, y, -1)


class AveragePooling(nn.Module):

    def forward(self, x, mask, dim=1):
        return torch.sum(
            x * mask.float(), dim=dim) / torch.sum(
                mask.float(), dim=dim)


class AttnPooling(nn.Module):

    def __init__(self, input_size, hidden_size=None, output_size=None):
        super().__init__()
        self.input_proj = nn.Sequential(
            LinearProjection(input_size, hidden_size), nn.Tanh(),
            LinearProjection(hidden_size, 1, bias=False))
        self.output_proj = LinearProjection(
            input_size, output_size) if output_size else lambda x: x

    def forward(self, x, mask):
        score = self.input_proj(x)
        score = score * mask.float() + -1e4 * (1. - mask.float())
        score = F.softmax(score, dim=1)
        features = self.output_proj(x)
        return torch.matmul(score.transpose(1, 2), features).squeeze(1)


class PoolingLayer(nn.Module):

    def __init__(self, args):
        super(PoolingLayer, self).__init__()
        if args.pooling == 'attn':
            self.pooling = AttnPooling(args.proj_hidden_size,
                                       args.proj_hidden_size,
                                       args.proj_hidden_size)
        elif args.pooling == 'avg':
            self.pooling = AveragePooling()
        else:
            raise NotImplementedError(args.pooling)

    def forward(self, x, mask):
        return self.pooling(x, mask)
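
Reviewer note: the forward pass above is prototypical-network style few-shot classification — support embeddings are averaged per label into class prototypes, and each query is scored against every prototype with the configured metric. A minimal standalone sketch of that scoring step with toy tensors (plain PyTorch, no ModelScope dependencies; all names are illustrative):

import torch
import torch.nn.functional as F

# Toy shapes: 2 queries, 4 support sentences over 2 classes, embedding dim 3.
z_query = torch.randn(2, 3)
z_support = torch.randn(4, 3)
support_labels = torch.tensor([0, 0, 1, 1])

num_cls = int(support_labels.max()) + 1
onehot = F.one_hot(support_labels, num_cls).float()  # [4, 2]

# Class prototype = mean embedding of the support sentences with that label
# (the 1e-5 mirrors the epsilon in SbertForFaqQuestionAnswering.forward).
cls_n_support = onehot.sum(dim=-2) + 1e-5  # [2]
protos = onehot.t().matmul(z_support) / cls_n_support.unsqueeze(-1)  # [2, 3]

# Cosine score of every query against every prototype, as in MetricsLayer.
scores = F.cosine_similarity(
    z_query.unsqueeze(1), protos.unsqueeze(0), dim=-1)  # [2, 2]
print(scores)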
@@ -7,6 +7,7 @@ class OutputKeys(object):
     LOSS = 'loss'
     LOGITS = 'logits'
     SCORES = 'scores'
+    SCORE = 'score'
     LABEL = 'label'
     LABELS = 'labels'
     INPUT_IDS = 'input_ids'
@@ -504,6 +505,16 @@ TASK_OUTPUTS = {
     # }
     Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],

+    # {
+    #     'output': [
+    #         [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
+    #          {'label': '13421097', 'score': 2.2825044965202324e-08}],
+    #         [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
+    #          {'label': '13421097', 'score': 2.75914817393641e-06}],
+    #         [{'label': '1000012000', 'score': 0.910681426525116}, {'label': '6527856', 'score': 0.0005046309670433402},
+    #          {'label': '13421097', 'score': 2.75914817393641e-06}]]
+    # }
+    Tasks.faq_question_answering: [OutputKeys.OUTPUT],
+
     # image person reid result for single sample
     # {
     #     "img_embedding": np.array with shape [1, D],
@@ -129,6 +129,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                             'damo/cv_convnextTiny_ocr-recognition-general_damo'),
     Tasks.skin_retouching: (Pipelines.skin_retouching,
                             'damo/cv_unet_skin-retouching'),
+    Tasks.faq_question_answering:
+    (Pipelines.faq_question_answering,
+     'damo/nlp_structbert_faq-question-answering_chinese-base'),
     Tasks.crowd_counting: (Pipelines.crowd_counting,
                           'damo/cv_hrnet_crowd-counting_dcanet'),
     Tasks.video_single_object_tracking:
@@ -218,7 +221,6 @@ def pipeline(task: str = None,
         f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'
     model = normalize_model_input(model, model_revision)
-
     if pipeline_name is None:
         # get default pipeline for this task
         if isinstance(model, str) \
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
     from .summarization_pipeline import SummarizationPipeline
     from .text_classification_pipeline import TextClassificationPipeline
     from .text_error_correction_pipeline import TextErrorCorrectionPipeline
+    from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
 else:
     _import_structure = {
@@ -44,7 +45,8 @@ else:
         'translation_pipeline': ['TranslationPipeline'],
         'summarization_pipeline': ['SummarizationPipeline'],
         'text_classification_pipeline': ['TextClassificationPipeline'],
-        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline']
+        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
+        'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline']
     }

     import sys
@@ -0,0 +1,76 @@ (new file)
from typing import Any, Dict, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import SbertForFaqQuestionAnswering
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks

__all__ = ['FaqQuestionAnsweringPipeline']


@PIPELINES.register_module(
    Tasks.faq_question_answering, module_name=Pipelines.faq_question_answering)
class FaqQuestionAnsweringPipeline(Pipeline):

    def __init__(self,
                 model: Union[str, SbertForFaqQuestionAnswering],
                 preprocessor: FaqQuestionAnsweringPreprocessor = None,
                 **kwargs):
        model = model if isinstance(
            model,
            SbertForFaqQuestionAnswering) else Model.from_pretrained(model)
        model.eval()
        if preprocessor is None:
            preprocessor = FaqQuestionAnsweringPreprocessor(
                model.model_dir, **kwargs)
        self.preprocessor = preprocessor
        super(FaqQuestionAnsweringPipeline, self).__init__(
            model=model, preprocessor=preprocessor, **kwargs)

    def _sanitize_parameters(self, **pipeline_parameters):
        return pipeline_parameters, pipeline_parameters, pipeline_parameters

    def get_sentence_embedding(self, inputs, max_len=None):
        inputs = self.preprocessor.batch_encode(inputs, max_length=max_len)
        sentence_vecs = self.model.forward_sentence_embedding(inputs)
        sentence_vecs = sentence_vecs.detach().tolist()
        return sentence_vecs

    def forward(self, inputs: Union[list, Dict[str, Any]],
                **forward_params) -> Dict[str, Any]:
        with torch.no_grad():
            return self.model(inputs)

    def postprocess(self, inputs: Union[list, Dict[str, Any]],
                    **postprocess_params) -> Dict[str, Any]:
        scores = inputs['scores']
        labels = []
        for item in scores:
            tmplabels = [
                self.preprocessor.get_label(label_id)
                for label_id in range(len(item))
            ]
            labels.append(tmplabels)

        predictions = []
        for tmp_scores, tmp_labels in zip(scores.tolist(), labels):
            prediction = []
            for score, label in zip(tmp_scores, tmp_labels):
                prediction.append({
                    OutputKeys.LABEL: label,
                    OutputKeys.SCORE: score
                })
            # Sort candidate labels by descending score.
            predictions.append(
                list(
                    sorted(
                        prediction,
                        key=lambda d: d[OutputKeys.SCORE],
                        reverse=True)))
        return {OutputKeys.OUTPUT: predictions}
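
Reviewer note: end to end, calling the new pipeline looks roughly like the sketch below. The query_set/support_set schema comes from the preprocessor later in this review; the English strings are illustrative only (the default model is Chinese, so real inputs would look like the fixtures in the test file at the end of this review):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(
    task=Tasks.faq_question_answering,
    model='damo/nlp_structbert_faq-question-answering_chinese-base')
result = p({
    'query_set': ['how do I use a coupon'],
    'support_set': [
        {'text': 'how to redeem a voucher', 'label': 'coupon'},
        {'text': 'where do I claim coupons', 'label': 'coupon'},
        {'text': 'how do I raise my member level', 'label': 'membership'},
    ],
})
# Per TASK_OUTPUTS above, result['output'] holds one list per query of
# {'label': ..., 'score': ...} dicts, sorted by descending score.
print(result)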
@@ -21,7 +21,8 @@ if TYPE_CHECKING:
         SingleSentenceClassificationPreprocessor,
         PairSentenceClassificationPreprocessor,
         FillMaskPreprocessor, ZeroShotClassificationPreprocessor,
-        NERPreprocessor, TextErrorCorrectionPreprocessor)
+        NERPreprocessor, TextErrorCorrectionPreprocessor,
+        FaqQuestionAnsweringPreprocessor)
     from .space import (DialogIntentPredictionPreprocessor,
                         DialogModelingPreprocessor,
                         DialogStateTrackingPreprocessor)
@@ -48,7 +49,8 @@ else:
             'SingleSentenceClassificationPreprocessor',
             'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
             'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
-            'TextErrorCorrectionPreprocessor'
+            'TextErrorCorrectionPreprocessor',
+            'FaqQuestionAnsweringPreprocessor'
         ],
         'space': [
             'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
@@ -5,10 +5,12 @@ import uuid
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import numpy as np
 import torch
 from transformers import AutoTokenizer

 from modelscope.metainfo import Models, Preprocessors
+from modelscope.outputs import OutputKeys
+from modelscope.utils.config import ConfigFields
 from modelscope.utils.constant import Fields, InputFields, ModeKeys
 from modelscope.utils.hub import get_model_type, parse_label_mapping
 from modelscope.utils.type_assert import type_assert
@@ -21,7 +23,7 @@ __all__ = [
     'PairSentenceClassificationPreprocessor',
     'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
     'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
-    'TextErrorCorrectionPreprocessor'
+    'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor'
 ]
@@ -645,3 +647,86 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
         sample = dict()
         sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
         return sample


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.faq_question_answering_preprocessor)
class FaqQuestionAnsweringPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        super(FaqQuestionAnsweringPreprocessor, self).__init__(
            model_dir, pair=False, mode=ModeKeys.INFERENCE, **kwargs)
        import os
        from transformers import BertTokenizer
        from modelscope.utils.config import Config
        from modelscope.utils.constant import ModelFile
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        preprocessor_config = Config.from_file(
            os.path.join(model_dir, ModelFile.CONFIGURATION)).get(
                ConfigFields.preprocessor, {})
        self.MAX_LEN = preprocessor_config.get('max_seq_length', 50)
        self.label_dict = None

    def pad(self, samples, max_len):
        # Truncate to max_len, then right-pad with the tokenizer's pad id.
        result = []
        for sample in samples:
            pad_len = max_len - len(sample[:max_len])
            result.append(sample[:max_len]
                          + [self.tokenizer.pad_token_id] * pad_len)
        return result

    def set_label_dict(self, label_dict):
        self.label_dict = label_dict

    def get_label(self, label_id):
        assert self.label_dict is not None and label_id < len(self.label_dict)
        return self.label_dict[label_id]

    def encode_plus(self, text):
        return [
            self.tokenizer.cls_token_id
        ] + self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.tokenize(text)) + [self.tokenizer.sep_token_id]

    @type_assert(object, Dict)
    def __call__(self, data: Dict[str, Any],
                 **preprocessor_param) -> Dict[str, Any]:
        TMP_MAX_LEN = preprocessor_param.get('max_seq_length', self.MAX_LEN)
        queryset = data['query_set']
        if not isinstance(queryset, list):
            queryset = [queryset]
        supportset = data['support_set']
        supportset = sorted(supportset, key=lambda d: d['label'])

        queryset_tokenized = [self.encode_plus(text) for text in queryset]
        supportset_tokenized = [
            self.encode_plus(item['text']) for item in supportset
        ]

        max_len = max(
            [len(seq) for seq in queryset_tokenized + supportset_tokenized])
        max_len = min(TMP_MAX_LEN, max_len)
        queryset_padded = self.pad(queryset_tokenized, max_len)
        supportset_padded = self.pad(supportset_tokenized, max_len)

        # Build the label vocabulary in first-seen order and remember it so
        # that postprocessing can map label ids back to label strings.
        supportset_labels_ori = [item['label'] for item in supportset]
        label_dict = []
        for label in supportset_labels_ori:
            if label not in label_dict:
                label_dict.append(label)
        self.set_label_dict(label_dict)
        supportset_labels_ids = [
            label_dict.index(label) for label in supportset_labels_ori
        ]
        return {
            'query': queryset_padded,
            'support': supportset_padded,
            'support_labels': supportset_labels_ids
        }

    def batch_encode(self, sentence_list: list, max_length=None):
        if not max_length:
            max_length = self.MAX_LEN
        return self.tokenizer.batch_encode_plus(
            sentence_list, padding=True, max_length=max_length)
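
One subtlety worth flagging for review: __call__ sorts the support set by label and only then builds the label vocabulary in first-seen order, so the label ids the model sees (and that get_label reverses) depend on that sorted order. An isolated sketch of just that bookkeeping, with toy labels:

# Toy support set; only the labels matter for this illustration.
support = [{'label': 'B'}, {'label': 'A'}, {'label': 'A'}]
support = sorted(support, key=lambda d: d['label'])

label_dict = []
for item in support:
    if item['label'] not in label_dict:
        label_dict.append(item['label'])

support_labels = [label_dict.index(item['label']) for item in support]
print(label_dict)      # ['A', 'B']
print(support_labels)  # [0, 0, 1]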
@@ -95,6 +95,7 @@ class NLPTasks(object):
     zero_shot_classification = 'zero-shot-classification'
     backbone = 'backbone'
     text_error_correction = 'text-error-correction'
+    faq_question_answering = 'faq-question-answering'
     conversational_text_to_sql = 'conversational-text-to-sql'
@@ -0,0 +1,85 @@ (new file)
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.hub.api import HubApi
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForFaqQuestionAnswering
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import FaqQuestionAnsweringPipeline
from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class FaqQuestionAnsweringTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_faq-question-answering_chinese-base'
    # Chinese fixtures: the queries ask how/where to use and claim coupons;
    # the support labels group coupon-usage ('6527856'), coupon-claiming
    # ('1000012000') and shopping-level ('13421097') questions.
    param = {
        'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'],
        'support_set': [{
            'text': '卖品代金券怎么用',
            'label': '6527856'
        }, {
            'text': '怎么使用优惠券',
            'label': '6527856'
        }, {
            'text': '这个可以一起领吗',
            'label': '1000012000'
        }, {
            'text': '付款时送的优惠券哪里领',
            'label': '1000012000'
        }, {
            'text': '购物等级怎么长',
            'label': '13421097'
        }, {
            'text': '购物等级二心',
            'label': '13421097'
        }]
    }

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_direct_file_download(self):
        cache_path = snapshot_download(self.model_id)
        preprocessor = FaqQuestionAnsweringPreprocessor(cache_path)
        model = SbertForFaqQuestionAnswering(cache_path)
        model.load_checkpoint(cache_path)
        pipeline_ins = FaqQuestionAnsweringPipeline(
            model, preprocessor=preprocessor)
        result = pipeline_ins(self.param)
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        preprocessor = FaqQuestionAnsweringPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.faq_question_answering,
            model=model,
            preprocessor=preprocessor)
        result = pipeline_ins(self.param)
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.faq_question_answering, model=self.model_id)
        result = pipeline_ins(self.param)
        print(result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.faq_question_answering)
        print(pipeline_ins(self.param, max_seq_length=20))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_sentence_embedding(self):
        pipeline_ins = pipeline(task=Tasks.faq_question_answering)
        sentence_vec = pipeline_ins.get_sentence_embedding(
            ['今天星期六', '明天星期几明天星期几'])
        print(np.shape(sentence_vec))


if __name__ == '__main__':
    unittest.main()