diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 37d6f1e3..b1d82557 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -239,7 +239,6 @@ class Pipeline(ABC): """ from torch.utils.data.dataloader import default_collate from modelscope.preprocessors import InputFeatures - from text2sql_lgesql.utils.batch import Batch if isinstance(data, dict) or isinstance(data, Mapping): return type(data)( {k: self._collate_fn(v) @@ -260,8 +259,6 @@ class Pipeline(ABC): return data elif isinstance(data, InputFeatures): return data - elif isinstance(data, Batch): - return data else: import mmcv if isinstance(data, mmcv.parallel.data_container.DataContainer): diff --git a/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py index 875c47fd..399dad5a 100644 --- a/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py +++ b/modelscope/pipelines/nlp/conversational_text_to_sql_pipeline.py @@ -64,3 +64,6 @@ class ConversationalTextToSqlPipeline(Pipeline): sql = Example.evaluator.obtain_sql(inputs['predict'][0], inputs['db']) result = {OutputKeys.TEXT: sql} return result + + def _collate_fn(self, data): + return data