| @@ -9,46 +9,6 @@ import torch | |||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| import transformers | import transformers | ||||
| from torch.nn.modules.loss import _Loss | from torch.nn.modules.loss import _Loss | ||||
| from torch.utils.data import Dataset | |||||
| from modelscope.preprocessors.multi_modal import OfaPreprocessor | |||||
class OFADataset(Dataset):
    """Dataset that reads selected columns of a file and preprocesses each row.

    Rows are fetched from an ``OFAFileDataset``; the values of each row are
    paired with the configured column keys to form a dict, which is then
    passed through ``preprocessor``.
    """

    def __init__(self,
                 file_path: str,
                 preprocessor: 'OfaPreprocessor',
                 selected_id_keys: str,
                 dtypes=None,
                 separator='\t',
                 cached_index=False,
                 **kwargs):
        """Parse the column selection and open the underlying file dataset.

        Args:
            file_path: Forwarded to ``OFAFileDataset`` as ``file_path``.
            preprocessor: Callable applied to the ``{key: value}`` dict built
                for every row in ``__getitem__``.
            selected_id_keys: Comma-separated ``id:key`` pairs, e.g.
                ``'0:image,1:text'``. The ids are joined with ``','`` and
                forwarded as ``selected_col_ids``; the keys become the dict
                keys of each sample, in the same order.
            dtypes: Forwarded unchanged to ``OFAFileDataset``.
            separator: Forwarded unchanged to ``OFAFileDataset``.
            cached_index: Forwarded unchanged to ``OFAFileDataset``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            AssertionError: If ``selected_id_keys`` is ``None``.
            ValueError: If an entry does not split into exactly one id and
                one key around ``':'``.
        """
        assert selected_id_keys is not None
        selected_col_ids = []
        selected_col_keys = []
        for id_key in selected_id_keys.split(','):
            col_id, key = id_key.split(':')
            selected_col_ids.append(col_id)
            selected_col_keys.append(key)
        # __getitem__ zips these keys with the row values, so they must live
        # on the instance; keeping them only as a local made every item
        # lookup fail with AttributeError.
        self.selected_col_keys = selected_col_keys
        self.dataset = OFAFileDataset(
            file_path=file_path,
            selected_col_ids=','.join(selected_col_ids),
            dtypes=dtypes,
            separator=separator,
            cached_index=cached_index)
        self.preprocessor = preprocessor

    def __len__(self):
        """Return the number of rows reported by the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the preprocessed sample for row ``index``.

        The row values are matched positionally with ``selected_col_keys``;
        ``zip`` stops at the shorter of the two sequences.
        """
        values = self.dataset[index]
        data = dict(zip(self.selected_col_keys, values))
        return self.preprocessor(data)
| def construct_rdrop_sample(x): | def construct_rdrop_sample(x): | ||||