From be6b82cd6de7f8594bcaf9b4dc7098958b8a5d43 Mon Sep 17 00:00:00 2001 From: "piaoyu.lxy" Date: Thu, 11 Aug 2022 20:11:06 +0800 Subject: [PATCH] [to #42322933] fix modelscope/pipelines/base.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 删除base.py _collate_fn()函数中关于text2sql模型的相关代码,挪到ConversationalTextToSqlPipeline相关的代码里 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9730687 --- modelscope/pipelines/base.py | 3 --- .../pipelines/nlp/conversational_text_to_sql_pipeline.py | 3 +++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 37d6f1e3..b1d82557 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -239,7 +239,6 @@ class Pipeline(ABC): """ from torch.utils.data.dataloader import default_collate from modelscope.preprocessors import InputFeatures - from text2sql_lgesql.utils.batch import Batch if isinstance(data, dict) or isinstance(data, Mapping): return type(data)( {k: self._collate_fn(v) @@ -260,8 +259,6 @@ class Pipeline(ABC): return data elif isinstance(data, InputFeatures): return data - elif isinstance(data, Batch): - return data else: import mmcv if isinstance(data, mmcv.parallel.data_container.DataContainer): diff --git a/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py index 875c47fd..399dad5a 100644 --- a/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py +++ b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py @@ -64,3 +64,6 @@ class ConversationalTextToSqlPipeline(Pipeline): sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) result = {OutputKeys.TEXT: sql} return result + + def _collate_fn(self, data): + return data