@@ -148,7 +148,7 @@ class BeamSearch(Search):
        scores_buf = top_prediction[0]
        indices_buf = top_prediction[1]
        # Project back into relative indices and beams
        beams_buf = indices_buf // vocab_size
        beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor')
        indices_buf = indices_buf.fmod(vocab_size)

        # At this point, beams_buf and indices_buf are single-dim and contain relative indices
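Floor division with `//` on tensors is deprecated in recent PyTorch; `torch.div(..., rounding_mode='floor')` gives the same quotient without the warning. A minimal standalone sketch of the equivalence (values are illustrative, not tied to BeamSearch state):

```python
import torch

# Flattened top-k indices over (beam, vocab); values are made up.
vocab_size = 5
indices_buf = torch.tensor([3, 7, 12])

# Old form: tensor floor division, deprecated in newer PyTorch.
beams_old = indices_buf // vocab_size

# New form used in the diff: explicit floor rounding, same result.
beams_new = torch.div(indices_buf, vocab_size, rounding_mode='floor')
tokens = indices_buf.fmod(vocab_size)

assert torch.equal(beams_old, beams_new)
print(beams_new.tolist(), tokens.tolist())  # [0, 1, 2] [3, 2, 2]
```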
@@ -385,12 +385,7 @@ class SequenceGenerator(nn.Module):
                    attn = torch.empty(bsz * beam_size,
                                       avg_attn_scores.size(1),
                                       max_len + 2).to(scores)
                # print("+++++++ debug attention shape +++++++")
                # print("attn", attn.shape)
                # print("avg_attn_scores", avg_attn_scores.shape)
                attn[:, :, step + 1].copy_(avg_attn_scores)
                # print("attn[:, :, step + 1]", attn[:, :, step + 1].shape)
                # print("attn", attn.shape)

            scores = scores.type_as(lprobs)
            eos_bbsz_idx = torch.empty(0).to(
@@ -403,7 +398,8 @@ class SequenceGenerator(nn.Module):
            if self.should_set_src_lengths:
                self.search.set_src_lengths(src_lengths)

            if self.repeat_ngram_blocker is not None:
            if self.repeat_ngram_blocker is not None and step > prefix_tokens.size(
                    1):
                lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz,
                                                   beam_size, step)
@@ -415,7 +411,6 @@ class SequenceGenerator(nn.Module):
                tokens[:, :step + 1],
                original_batch_idxs,
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
@@ -671,7 +666,7 @@ class SequenceGenerator(nn.Module):
        cum_unfin.append(prev)
        cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)

        unfin_idx = bbsz_idx // beam_size
        unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode='floor')
        sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)

        # Create a set of "{sent}{unfin_idx}", where
@@ -114,7 +114,8 @@ def make_image_bucket_position(bucket_size, num_relative_distance):
    """
    coords_h = torch.arange(bucket_size)
    coords_w = torch.arange(bucket_size)
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords = torch.stack(torch.meshgrid([coords_h, coords_w],
                                        indexing='ij'))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - \
        coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
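`torch.meshgrid` warns on recent PyTorch unless `indexing` is given explicitly; `indexing='ij'` preserves the previous default layout, so the bucket-position math is unchanged. A small self-contained sketch:

```python
import torch

bucket_size = 3
coords_h = torch.arange(bucket_size)
coords_w = torch.arange(bucket_size)

# 'ij' indexing: first output varies along rows, second along columns,
# matching the old default behaviour of torch.meshgrid.
grid_h, grid_w = torch.meshgrid([coords_h, coords_w], indexing='ij')
coords = torch.stack([grid_h, grid_w])      # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)   # 2, Wh*Ww

print(coords.shape)          # torch.Size([2, 3, 3])
print(coords_flatten.shape)  # torch.Size([2, 9])
```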
@@ -7,7 +7,7 @@ OFA_TASK_KEY_MAPPING = {
    Tasks.summarization: OutputKeys.TEXT,
    Tasks.visual_question_answering: OutputKeys.TEXT,
    Tasks.visual_grounding: OutputKeys.BOXES,
    Tasks.text_classification: (OutputKeys.SCORES, OutputKeys.LABELS),
    Tasks.text_classification: OutputKeys.LABELS,
    Tasks.image_classification: OutputKeys.LABELS,
    Tasks.visual_entailment: (OutputKeys.SCORES, OutputKeys.LABELS),
    Tasks.visual_entailment: OutputKeys.LABELS,
}
@@ -127,10 +127,23 @@ class OfaForAllTasks(TorchModel):
        return input

    def _text_gen_inference(self, input):
        input = move_to_device(input, self._device)
        gen_output = self.generator.generate([self.model], input)
        gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))]
        result = self.tokenizer.batch_decode(gen, skip_special_tokens=True)
        if 'prefix_tokens' in input:
            gen_output = self.generator.generate(
                [self.model], input, prefix_tokens=input['prefix_tokens'])
        else:
            gen_output = self.generator.generate([self.model], input)
        gen_l = list()
        for i in range(len(gen_output)):
            if 'prefix_tokens' in input:
                prefix_tokens = input['prefix_tokens']
                gen_l.append(
                    gen_output[i][0]['tokens'][len(prefix_tokens[i]):])
            else:
                gen_l.append(gen_output[i][0]['tokens'])
        result = self.tokenizer.batch_decode(gen_l, skip_special_tokens=True)
        # text generation tasks have no score
        ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result}
        if self.cfg.task.endswith('classification'):
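With `prefix_tokens`, the forced prompt comes back at the head of every hypothesis, so the loop above slices it off before decoding. A standalone sketch of that slicing with made-up tensors (the names `gen_output` and `prefix_tokens` mirror the diff, the values are illustrative):

```python
import torch

# One hypothesis per sample, as returned by the beam-search generator:
# a list (batch) of lists (beams) of dicts holding a 'tokens' tensor.
gen_output = [
    [{'tokens': torch.tensor([5, 6, 101, 102, 2])}],
    [{'tokens': torch.tensor([7, 8, 201, 2])}],
]
# The forced prefix for each sample, already present in the hypotheses.
prefix_tokens = torch.tensor([[5, 6], [7, 8]])

gen_l = []
for i in range(len(gen_output)):
    # Drop the forced prefix so only the newly generated tail is decoded.
    gen_l.append(gen_output[i][0]['tokens'][len(prefix_tokens[i]):])

print([t.tolist() for t in gen_l])  # [[101, 102, 2], [201, 2]]
```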
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from io import BytesIO
from typing import Any, Dict, List, Union

import torch
@@ -8,6 +9,7 @@ from PIL import Image
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Preprocessors
from modelscope.pipelines.base import Input
from modelscope.preprocessors.image import load_image
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModelFile, Tasks
from .base import Preprocessor
@@ -71,20 +73,32 @@ class OfaPreprocessor(Preprocessor):
            data[key] = item
        return data

    def _compatible_with_pretrain(self, data):
        if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
            image = load_image(data['image'])
            img_buffer = BytesIO()
            image.save(img_buffer, format='JPEG')
            data['image'] = Image.open(img_buffer)
        return data

    def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
                 **kwargs) -> Dict[str, Any]:
        if isinstance(input, dict):
            data = input
        else:
            data = self._build_dict(input)
        data = self._compatible_with_pretrain(data)
        sample = self.preprocess(data)
        str_data = dict()
        for k, v in data.items():
            str_data[k] = str(v)
        sample['sample'] = str_data
        return collate_fn([sample],
                          pad_idx=self.tokenizer.pad_token_id,
                          eos_idx=self.tokenizer.eos_token_id)
        if kwargs.get('no_collate', None):
            return sample
        else:
            return collate_fn([sample],
                              pad_idx=self.tokenizer.pad_token_id,
                              eos_idx=self.tokenizer.eos_token_id)


@PREPROCESSORS.register_module(
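The new `no_collate` switch lets callers get back the raw preprocessed sample instead of a collated one-element batch, which is what the dataset below relies on at training time. A hedged usage sketch; the model id is only an example and the keyword usage is assumed:

```python
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.preprocessors.multi_modal import OfaPreprocessor

# Placeholder model id; any OFA checkpoint directory works the same way.
model_dir = snapshot_download('damo/ofa_image-caption_coco_large_en')
preprocessor = OfaPreprocessor(model_dir)

# Default path: returns a collated batch, ready for the pipeline/model.
batch = preprocessor({'image': 'path/to/img.jpg'})

# Training path: skip collation and let the dataloader batch samples later.
sample = preprocessor({'image': 'path/to/img.jpg'}, no_collate=True)
```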
@@ -13,7 +13,7 @@ from .utils.random_help import set_torch_seed
class OfaBasePreprocessor:

    def __init__(self, cfg, model_dir):
    def __init__(self, cfg, model_dir, split, *args, **kwargs):
        """preprocess the data via the vocab.txt from the `model_dir` path

        Args:
@@ -76,6 +76,7 @@ class OfaBasePreprocessor:
            text,
            max_length=self.max_src_length,
            add_special_tokens=False,
            truncation=True,
            return_tensors='pt')['input_ids'].squeeze(0)
        if add_bos:
            inputs = torch.cat([self.bos_item, inputs])
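`truncation=True` makes the tokenizer actually clip inputs to `max_length` instead of merely warning and returning the full sequence. A minimal standalone sketch with a stand-in Hugging Face tokenizer (the checkpoint name is illustrative, not the one OFA ships):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
text = 'a ' * 100  # deliberately longer than max_length

ids = tokenizer(
    text,
    max_length=8,
    add_special_tokens=False,
    truncation=True,
    return_tensors='pt')['input_ids'].squeeze(0)

print(ids.shape)  # torch.Size([8]), clipped to max_length
```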
@@ -6,24 +6,28 @@ from PIL import Image
from torchvision import transforms

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor


class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir):
        """preprocess the data via the vocab.txt from the `model_dir` path
    def __init__(self, cfg, model_dir, split, *args, **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path
            model_dir (str): model path,
            split: data phase
        """
        super(OfaImageCaptioningPreprocessor, self).__init__(cfg, model_dir)
        super(OfaImageCaptioningPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),
            transforms.Resize((self.patch_image_size, self.patch_image_size),
                              interpolation=Image.BICUBIC),
            transforms.Resize(
                (self.patch_image_size, self.patch_image_size),
                interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])
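Passing the PIL constant `Image.BICUBIC` to `transforms.Resize` is deprecated in recent torchvision; `transforms.InterpolationMode.BICUBIC` is the enum it expects now, with the same resampling. The same change is repeated in the other preprocessors below, so one standalone sketch covers them (patch size and normalization stats are placeholders, not OFA's actual values):

```python
from PIL import Image
from torchvision import transforms

patch_image_size = 224
mean = [0.5, 0.5, 0.5]  # placeholder stats
std = [0.5, 0.5, 0.5]

patch_resize_transform = transforms.Compose([
    lambda image: image.convert('RGB'),
    transforms.Resize(
        (patch_image_size, patch_image_size),
        interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

img = Image.new('RGB', (640, 480))
print(patch_resize_transform(img).shape)  # torch.Size([3, 224, 224])
```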
@@ -11,20 +11,22 @@ from .base import OfaBasePreprocessor
class OfaImageClassificationPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir):
        """preprocess the data via the vocab.txt from the `model_dir` path
    def __init__(self, cfg, model_dir, split, *args, **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path
            model_dir (str): model path,
            split: data phase
        """
        super(OfaImageClassificationPreprocessor,
              self).__init__(cfg, model_dir)
              self).__init__(cfg, model_dir, split, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),
            transforms.Resize((self.patch_image_size, self.patch_image_size),
                              interpolation=Image.BICUBIC),
            transforms.Resize(
                (self.patch_image_size, self.patch_image_size),
                interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])
@@ -6,14 +6,16 @@ from .base import OfaBasePreprocessor
class OfaSummarizationPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir):
        """preprocess the data via the vocab.txt from the `model_dir` path
    def __init__(self, cfg, model_dir, split, *args, **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path
            model_dir (str): model path,
            split: data phase
        """
        super(OfaSummarizationPreprocessor, self).__init__(cfg, model_dir)
        super(OfaSummarizationPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        source = super().pre_caption(
@@ -6,14 +6,16 @@ from .base import OfaBasePreprocessor
class OfaTextClassificationPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir):
        """preprocess the data via the vocab.txt from the `model_dir` path
    def __init__(self, cfg, model_dir, split, *args, **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path
            model_dir (str): model path,
            split: data phase
        """
        super(OfaTextClassificationPreprocessor, self).__init__(cfg, model_dir)
        super(OfaTextClassificationPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        text1 = ' '.join(
@@ -34,5 +36,6 @@ class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
        sample = {
            'source': inputs,
            'decoder_prompt': decoder_prompt,
            'prefix_token': decoder_prompt[:-1],
        }
        return sample
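The new `prefix_token` entry is everything in the decoder prompt except its last token; `collate_fn` below stacks it into `prefix_tokens`, which the pipeline then passes to `generator.generate` (see the `_text_gen_inference` hunk above). A toy sketch of that relationship, with made-up token ids:

```python
import torch

# Hypothetical encoded decoder prompt: BOS followed by prompt token ids.
decoder_prompt = torch.tensor([0, 1184, 52, 731, 9])

sample = {
    'decoder_prompt': decoder_prompt,
    # Everything except the last token is forced as a generation prefix.
    'prefix_token': decoder_prompt[:-1],
}

print(sample['prefix_token'].tolist())  # [0, 1184, 52, 731]
```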
@@ -8,14 +8,16 @@ from .base import OfaBasePreprocessor
class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir):
        """preprocess the data via the vocab.txt from the `model_dir` path
    def __init__(self, cfg, model_dir, split, *args, **kwargs):
        """preprocess the data

        Args:
            model_dir (str): model path
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            split: data phase
        """
        super(OfaTextToImageSynthesisPreprocessor,
              self).__init__(cfg, model_dir)
              self).__init__(cfg, model_dir, split, *args, **kwargs)
        self.max_src_length = 64

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -50,8 +50,11 @@ def collate_fn(samples, pad_idx, eos_idx):
    if samples[0].get('constraint_mask', None) is not None:
        batch['constraint_masks'] = merge('constraint_mask')
    if samples[0].get('decoder_prompt', None) is not None:
        batch['decoder_prompts'] = np.array(
            [s['decoder_prompt'].tolist() for s in samples])
        batch['decoder_prompts'] = torch.stack(
            [s['decoder_prompt'] for s in samples], dim=0)
    if samples[0].get('prefix_token', None) is not None:
        batch['prefix_tokens'] = torch.stack(
            [s['prefix_token'] for s in samples], dim=0)
    # For detection and visual grounding
    if samples[0].get('w_resize_ratio', None) is not None:
        batch['w_resize_ratios'] = torch.stack(
@@ -11,19 +11,22 @@ from .base import OfaBasePreprocessor
class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir):
        """preprocess the data via the vocab.txt from the `model_dir` path
    def __init__(self, cfg, model_dir, split, *args, **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path
            model_dir (str): model path,
            split: data phase
        """
        super(OfaVisualEntailmentPreprocessor, self).__init__(cfg, model_dir)
        super(OfaVisualEntailmentPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),
            transforms.Resize((self.patch_image_size, self.patch_image_size),
                              interpolation=Image.BICUBIC),
            transforms.Resize(
                (self.patch_image_size, self.patch_image_size),
                interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])
@@ -11,19 +11,22 @@ from .base import OfaBasePreprocessor
class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir):
        """preprocess the data via the vocab.txt from the `model_dir` path
    def __init__(self, cfg, model_dir, split, *args, **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path
            model_dir (str): model path,
            split: data phase
        """
        super(OfaVisualGroundingPreprocessor, self).__init__(cfg, model_dir)
        super(OfaVisualGroundingPreprocessor,
              self).__init__(cfg, model_dir, split, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),
            transforms.Resize((self.patch_image_size, self.patch_image_size),
                              interpolation=Image.BICUBIC),
            transforms.Resize(
                (self.patch_image_size, self.patch_image_size),
                interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])
@@ -11,20 +11,22 @@ from .base import OfaBasePreprocessor
class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):

    def __init__(self, cfg, model_dir):
        """preprocess the data via the vocab.txt from the `model_dir` path
    def __init__(self, cfg, model_dir, split, *args, **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path
            model_dir (str): model path,
            split: data phase
        """
        super(OfaVisualQuestionAnsweringPreprocessor,
              self).__init__(cfg, model_dir)
              self).__init__(cfg, model_dir, split, *args, **kwargs)
        # Initialize transform
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),
            transforms.Resize((self.patch_image_size, self.patch_image_size),
                              interpolation=Image.BICUBIC),
            transforms.Resize(
                (self.patch_image_size, self.patch_image_size),
                interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])
@@ -0,0 +1,131 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import os
import pickle

import torch


class OFAFileDataset:

    def __init__(self,
                 file_path,
                 selected_col_ids=None,
                 dtypes=None,
                 separator='\t',
                 cached_index=False):
        self.file_path = file_path
        assert os.path.exists(
            self.file_path), 'Error: The local datafile {} does not exist!'.format(
                self.file_path)
        self.separator = separator
        if selected_col_ids is None:
            # default to all fields
            self.selected_col_ids = list(
                range(
                    len(
                        open(self.file_path).readline().rstrip('\n').split(
                            self.separator))))
        else:
            self.selected_col_ids = [
                int(col_id) for col_id in selected_col_ids.split(',')
            ]
        if dtypes is None:
            # default to str
            self.dtypes = [str for col_id in self.selected_col_ids]
        else:
            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(',')]
        assert len(self.dtypes) == len(self.selected_col_ids)

        self.data_cnt = 0
        try:
            self.slice_id = torch.distributed.get_rank()
            self.slice_count = torch.distributed.get_world_size()
        except Exception:
            self.slice_id = 0
            self.slice_count = 1
        self.cached_index = cached_index
        self._init_seek_index()
        self._reader = self._get_reader()
        print('file {} slice_id {} row count {} total row count {}'.format(
            self.file_path, self.slice_id, self.row_count,
            self.total_row_count))

    def _init_seek_index(self):
        if self.cached_index:
            cache_path = '{}.index'.format(self.file_path)
            assert os.path.exists(
                cache_path), 'cache file {} does not exist!'.format(cache_path)
            self.total_row_count, self.lineid_to_offset = pickle.load(
                open(cache_path, 'rb'))
            print(
                'local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping'
                .format(self.file_path, self.slice_id))
        else:
            # make an iteration over the file to get row_count and line_idx-to-offset mapping
            fp = open(self.file_path, 'r')
            print(
                'local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping'
                .format(self.file_path, self.slice_id))
            self.total_row_count = 0
            offset = 0
            self.lineid_to_offset = []
            for line in fp:
                self.lineid_to_offset.append(offset)
                self.total_row_count += 1
                offset += len(line.encode('utf-8'))
        self._compute_start_pos_and_row_count()
        print(
            'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'
            .format(self.file_path, self.slice_id))

    def _compute_start_pos_and_row_count(self):
        self.row_count = self.total_row_count // self.slice_count
        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
            self.row_count += 1
            self.start_pos = self.row_count * self.slice_id
        else:
            self.start_pos = self.row_count * self.slice_id + (
                self.total_row_count - self.row_count * self.slice_count)
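        # Illustrative worked example of the slicing above (not from the
        # original source): with total_row_count = 10 and slice_count = 3,
        # row_count starts at 10 // 3 = 3 and the remainder
        # (10 - 3 * 3 = 1) goes to the first slice:
        #   slice_id 0 -> row_count 4, start_pos 0
        #   slice_id 1 -> row_count 3, start_pos 4
        #   slice_id 2 -> row_count 3, start_pos 7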
    def _get_reader(self):
        fp = open(self.file_path, 'r')
        fp.seek(self.lineid_to_offset[self.start_pos])
        return fp

    def _seek(self, offset=0):
        try:
            print('slice_id {} seek offset {}'.format(self.slice_id,
                                                      self.start_pos + offset))
            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
            self.data_cnt = offset
        except Exception:
            print('slice_id {} seek offset {}'.format(self.slice_id, offset))
            self._reader.seek(self.lineid_to_offset[offset])
            self.data_cnt = offset

    def __del__(self):
        self._reader.close()

    def __len__(self):
        return self.row_count

    def get_total_row_count(self):
        return self.total_row_count

    def __getitem__(self, index):
        if self.data_cnt == self.row_count:
            print('reached the end of the datafile, starting a new reader')
            self.data_cnt = 0
            self._reader = self._get_reader()
        column_l = self._reader.readline().rstrip('\n').split(self.separator)
        self.data_cnt += 1
        column_l = [
            dtype(column_l[col_id])
            for col_id, dtype in zip(self.selected_col_ids, self.dtypes)
        ]
        return column_l
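A hedged usage sketch of the reader above, assuming `OFAFileDataset` is in scope; the TSV contents and column layout are invented for the example:

```python
import tempfile

# Build a tiny 3-column TSV: id, caption, score.
tmp = tempfile.NamedTemporaryFile('w', suffix='.tsv', delete=False)
tmp.write('0\ta dog on the grass\t0.9\n'
          '1\ttwo cats sleeping\t0.7\n')
tmp.close()

dataset = OFAFileDataset(
    file_path=tmp.name,
    selected_col_ids='0,1,2',   # keep all three columns
    dtypes='int,str,float',     # parsed with eval(), so plain type names
    separator='\t',
    cached_index=False)

print(len(dataset))  # rows assigned to this (single) slice: 2
print(dataset[0])    # [0, 'a dog on the grass', 0.9]
```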
@@ -0,0 +1,52 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from os import path as osp

from torch.utils.data import Dataset

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.preprocessors.multi_modal import OfaPreprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
from .ofa_file_dataset import OFAFileDataset


class OFADataset(Dataset):

    def __init__(self,
                 model_dir,
                 file_path,
                 dtypes=None,
                 separator='\t',
                 cached_index=False,
                 split=ModeKeys.TRAIN,
                 **kwargs):
        self.cfg = Config.from_file(
            osp.join(model_dir, ModelFile.CONFIGURATION))
        selected_col_ids = self.cfg.dataset.selected_col_ids
        selected_col_keys = self.cfg.dataset.selected_col_keys
        assert selected_col_ids is not None
        assert selected_col_keys is not None
        self.selected_col_key_l = selected_col_keys.split(',')
        assert len(self.selected_col_key_l) == len(selected_col_ids.split(','))
        self.dataset = OFAFileDataset(
            file_path=file_path,
            selected_col_ids=selected_col_ids,
            dtypes=dtypes,
            separator=separator,
            cached_index=cached_index)
        self.preprocessor = OfaPreprocessor(model_dir, split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        value_l = self.dataset[index]
        data = dict()
        for key, value in zip(self.selected_col_key_l, value_l):
            data[key] = value
        return self.preprocessor(data)
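A sketch of how this wrapper might be driven end to end. The model id and TSV path are placeholders, and it is an assumption that this particular model's configuration.json defines `dataset.selected_col_ids` and `dataset.selected_col_keys` matching the file's columns:

```python
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModeKeys

# Placeholder model id; any OFA model whose configuration carries the
# dataset.selected_col_ids / selected_col_keys fields would do.
model_dir = snapshot_download('damo/ofa_image-caption_coco_large_en')

dataset = OFADataset(
    model_dir=model_dir,
    file_path='train_captions.tsv',  # placeholder local TSV
    split=ModeKeys.TRAIN,
    cached_index=False)

print(len(dataset))
sample = dataset[0]  # preprocessed sample produced by OfaPreprocessor
print(sample.keys())
```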