|
|
|
@@ -9,46 +9,6 @@ import torch |
|
|
|
import torch.nn.functional as F |
|
|
|
import transformers |
|
|
|
from torch.nn.modules.loss import _Loss |
|
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
|
|
from modelscope.preprocessors.multi_modal import OfaPreprocessor |
|
|
|
|
|
|
|
|
|
|
|
class OFADataset(Dataset):
    """Dataset that feeds selected columns of a file to a preprocessor.

    Each item is produced by reading one row from the wrapped
    ``OFAFileDataset``, pairing the row's values with the configured keys,
    and handing the resulting dict to ``preprocessor``.
    """

    def __init__(self,
                 file_path: str,
                 preprocessor: OfaPreprocessor,
                 selected_id_keys: str,
                 dtypes=None,
                 separator='\t',
                 cached_index=False,
                 **kwargs):
        """Parse the column selection and open the underlying file dataset.

        Args:
            file_path: Forwarded unchanged to ``OFAFileDataset``.
            preprocessor: Callable applied to the per-row dict in
                ``__getitem__``.
            selected_id_keys: Comma-separated ``id:key`` pairs, e.g.
                ``'0:image,1:text'``. The ids are forwarded to
                ``OFAFileDataset`` as a comma-joined string; the keys become
                the names under which each row value is stored.
            dtypes: Forwarded unchanged to ``OFAFileDataset``.
            separator: Forwarded unchanged to ``OFAFileDataset``.
            cached_index: Forwarded unchanged to ``OFAFileDataset``.
            **kwargs: Accepted for call-site compatibility; not used here.

        Raises:
            AssertionError: If ``selected_id_keys`` is ``None``.
            ValueError: If a pair does not split into exactly two parts
                on ``':'``.
        """
        assert selected_id_keys is not None, \
            'selected_id_keys must not be None'
        selected_col_ids = []
        selected_col_keys = []
        for id_key in selected_id_keys.split(','):
            # `col_id` rather than `id` so the builtin is not shadowed.
            col_id, key = id_key.split(':')
            selected_col_ids.append(col_id)
            selected_col_keys.append(key)

        # __getitem__ reads these keys from the instance; without this
        # assignment every item access raised AttributeError.
        self.selected_col_keys = selected_col_keys

        self.dataset = OFAFileDataset(
            file_path=file_path,
            selected_col_ids=','.join(selected_col_ids),
            dtypes=dtypes,
            separator=separator,
            cached_index=cached_index)
        self.preprocessor = preprocessor

    def __len__(self):
        """Return the length reported by the underlying file dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the preprocessed sample for row ``index``.

        The row values are matched positionally with the selected keys;
        ``zip`` stops at the shorter of the two sequences.
        """
        values = self.dataset[index]
        data = dict(zip(self.selected_col_keys, values))
        return self.preprocessor(data)
|
|
|
|
|
|
|
|
|
|
|
def construct_rdrop_sample(x): |
|
|
|
|