|
|
@@ -239,7 +239,6 @@ class Pipeline(ABC): |
|
|
""" |
|
|
""" |
|
|
from torch.utils.data.dataloader import default_collate |
|
|
from torch.utils.data.dataloader import default_collate |
|
|
from modelscope.preprocessors import InputFeatures |
|
|
from modelscope.preprocessors import InputFeatures |
|
|
from text2sql_lgesql.utils.batch import Batch |
|
|
|
|
|
if isinstance(data, dict) or isinstance(data, Mapping): |
|
|
if isinstance(data, dict) or isinstance(data, Mapping): |
|
|
return type(data)( |
|
|
return type(data)( |
|
|
{k: self._collate_fn(v) |
|
|
{k: self._collate_fn(v) |
|
|
@@ -260,8 +259,6 @@ class Pipeline(ABC): |
|
|
return data |
|
|
return data |
|
|
elif isinstance(data, InputFeatures): |
|
|
elif isinstance(data, InputFeatures): |
|
|
return data |
|
|
return data |
|
|
elif isinstance(data, Batch): |
|
|
|
|
|
return data |
|
|
|
|
|
else: |
|
|
else: |
|
|
import mmcv |
|
|
import mmcv |
|
|
if isinstance(data, mmcv.parallel.data_container.DataContainer): |
|
|
if isinstance(data, mmcv.parallel.data_container.DataContainer): |
|
|
|