diff --git a/modelscope/models/multi_modal/ofa/generate/search.py b/modelscope/models/multi_modal/ofa/generate/search.py
index 63ecb0a9..0dcaf6b3 100644
--- a/modelscope/models/multi_modal/ofa/generate/search.py
+++ b/modelscope/models/multi_modal/ofa/generate/search.py
@@ -148,7 +148,7 @@ class BeamSearch(Search):
         scores_buf = top_prediction[0]
         indices_buf = top_prediction[1]
         # Project back into relative indices and beams
-        beams_buf = indices_buf // vocab_size
+        beams_buf = torch.div(indices_buf, vocab_size, rounding_mode='floor')
         indices_buf = indices_buf.fmod(vocab_size)
 
         # At this point, beams_buf and indices_buf are single-dim and contain relative indices
diff --git a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py
index 590fb67b..9d427836 100644
--- a/modelscope/models/multi_modal/ofa/generate/sequence_generator.py
+++ b/modelscope/models/multi_modal/ofa/generate/sequence_generator.py
@@ -385,12 +385,7 @@ class SequenceGenerator(nn.Module):
                     attn = torch.empty(bsz * beam_size,
                                        avg_attn_scores.size(1),
                                        max_len + 2).to(scores)
-                # print("+++++++ debug attention shape +++++++")
-                # print("attn", attn.shape)
-                # print("avg_attn_scores", avg_attn_scores.shape)
                 attn[:, :, step + 1].copy_(avg_attn_scores)
-                # print("attn[:, :, step + 1]", attn[:, :, step + 1].shape)
-                # print("attn", attn.shape)
 
             scores = scores.type_as(lprobs)
             eos_bbsz_idx = torch.empty(0).to(
@@ -403,7 +398,8 @@
             if self.should_set_src_lengths:
                 self.search.set_src_lengths(src_lengths)
 
-            if self.repeat_ngram_blocker is not None:
+            if self.repeat_ngram_blocker is not None and (
+                    prefix_tokens is None or step > prefix_tokens.size(1)):
                 lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz,
                                                    beam_size, step)
 
@@ -415,7 +411,6 @@
                 tokens[:, :step + 1],
                 original_batch_idxs,
             )
-
             # cand_bbsz_idx contains beam indices for the top candidate
             # hypotheses, with a range of values: [0, bsz*beam_size),
             # and dimensions: [bsz, cand_size]
@@ -671,7 +666,7 @@
                 cum_unfin.append(prev)
         cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
 
-        unfin_idx = bbsz_idx // beam_size
+        unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode='floor')
        sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
 
         # Create a set of "{sent}{unfin_idx}", where
diff --git a/modelscope/models/multi_modal/ofa/modeling_ofa.py b/modelscope/models/multi_modal/ofa/modeling_ofa.py
index 01cc02f9..4de35741 100755
--- a/modelscope/models/multi_modal/ofa/modeling_ofa.py
+++ b/modelscope/models/multi_modal/ofa/modeling_ofa.py
@@ -114,7 +114,8 @@ def make_image_bucket_position(bucket_size, num_relative_distance):
     """
     coords_h = torch.arange(bucket_size)
     coords_w = torch.arange(bucket_size)
-    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+    coords = torch.stack(torch.meshgrid([coords_h, coords_w],
+                                        indexing='ij'))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - \
         coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
diff --git a/modelscope/models/multi_modal/ofa/utils/constant.py b/modelscope/models/multi_modal/ofa/utils/constant.py
index 984da443..124afefa 100644
--- a/modelscope/models/multi_modal/ofa/utils/constant.py
+++ b/modelscope/models/multi_modal/ofa/utils/constant.py
@@ -7,7 +7,7 @@ OFA_TASK_KEY_MAPPING = {
     Tasks.summarization: OutputKeys.TEXT,
     Tasks.visual_question_answering: OutputKeys.TEXT,
     Tasks.visual_grounding: OutputKeys.BOXES,
-    Tasks.text_classification: (OutputKeys.SCORES, OutputKeys.LABELS),
+    Tasks.text_classification: OutputKeys.LABELS,
     Tasks.image_classification: OutputKeys.LABELS,
-    Tasks.visual_entailment: (OutputKeys.SCORES, OutputKeys.LABELS),
+    Tasks.visual_entailment: OutputKeys.LABELS,
 }
diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py
index 860b68d3..80471e3c 100644
--- a/modelscope/models/multi_modal/ofa_for_all_tasks.py
+++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py
@@ -127,10 +127,21 @@ class OfaForAllTasks(TorchModel):
         return input
 
     def _text_gen_inference(self, input):
         input = move_to_device(input, self._device)
-        gen_output = self.generator.generate([self.model], input)
-        gen = [gen_output[i][0]['tokens'] for i in range(len(gen_output))]
-        result = self.tokenizer.batch_decode(gen, skip_special_tokens=True)
+        if 'prefix_tokens' in input:
+            gen_output = self.generator.generate(
+                [self.model], input, prefix_tokens=input['prefix_tokens'])
+        else:
+            gen_output = self.generator.generate([self.model], input)
+        gen_l = list()
+        for i in range(len(gen_output)):
+            if 'prefix_tokens' in input:
+                prefix_tokens = input['prefix_tokens']
+                gen_l.append(
+                    gen_output[i][0]['tokens'][len(prefix_tokens[i]):])
+            else:
+                gen_l.append(gen_output[i][0]['tokens'])
+        result = self.tokenizer.batch_decode(gen_l, skip_special_tokens=True)
         # text generation tasks have no score
         ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result}
         if self.cfg.task.endswith('classification'):
diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index 65578e6a..46648832 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp
+from io import BytesIO
 from typing import Any, Dict, List, Union
 
 import torch
@@ -8,6 +9,7 @@ from PIL import Image
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Preprocessors
 from modelscope.pipelines.base import Input
+from modelscope.preprocessors.image import load_image
 from modelscope.utils.config import Config
 from modelscope.utils.constant import Fields, ModelFile, Tasks
 from .base import Preprocessor
@@ -71,20 +73,32 @@ class OfaPreprocessor(Preprocessor):
             data[key] = item
         return data
 
+    def _compatible_with_pretrain(self, data):
+        if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
+            image = load_image(data['image'])
+            img_buffer = BytesIO()
+            image.save(img_buffer, format='JPEG')
+            data['image'] = Image.open(img_buffer)
+        return data
+
     def __call__(self, input: Union[str, tuple, Dict[str, Any]], *args,
                  **kwargs) -> Dict[str, Any]:
         if isinstance(input, dict):
             data = input
         else:
             data = self._build_dict(input)
+        data = self._compatible_with_pretrain(data)
         sample = self.preprocess(data)
         str_data = dict()
         for k, v in data.items():
             str_data[k] = str(v)
         sample['sample'] = str_data
-        return collate_fn([sample],
-                          pad_idx=self.tokenizer.pad_token_id,
-                          eos_idx=self.tokenizer.eos_token_id)
+        if kwargs.get('no_collate', None):
+            return sample
+        else:
+            return collate_fn([sample],
+                              pad_idx=self.tokenizer.pad_token_id,
+                              eos_idx=self.tokenizer.eos_token_id)
 
 
 @PREPROCESSORS.register_module(
diff --git a/modelscope/preprocessors/ofa/base.py b/modelscope/preprocessors/ofa/base.py
index fb9d06cd..69286f69 100644
--- a/modelscope/preprocessors/ofa/base.py
+++ b/modelscope/preprocessors/ofa/base.py
@@ -13,7 +13,7 @@ from .utils.random_help import set_torch_seed
 
 class OfaBasePreprocessor:
 
-    def __init__(self, cfg, model_dir):
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
         """preprocess the data via the vocab.txt from the `model_dir` path
 
         Args:
@@ -76,6 +76,7 @@ class OfaBasePreprocessor:
             text,
             max_length=self.max_src_length,
             add_special_tokens=False,
+            truncation=True,
             return_tensors='pt')['input_ids'].squeeze(0)
         if add_bos:
             inputs = torch.cat([self.bos_item, inputs])
diff --git a/modelscope/preprocessors/ofa/image_captioning.py b/modelscope/preprocessors/ofa/image_captioning.py
index 264c8e04..3ea4ccb2 100644
--- a/modelscope/preprocessors/ofa/image_captioning.py
+++ b/modelscope/preprocessors/ofa/image_captioning.py
@@ -6,24 +6,28 @@ from PIL import Image
 from torchvision import transforms
 
 from modelscope.preprocessors.image import load_image
+from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
 
 
 class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
 
-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data
 
         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaImageCaptioningPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaImageCaptioningPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])
diff --git a/modelscope/preprocessors/ofa/image_classification.py b/modelscope/preprocessors/ofa/image_classification.py
index 30289613..a0cd0990 100644
--- a/modelscope/preprocessors/ofa/image_classification.py
+++ b/modelscope/preprocessors/ofa/image_classification.py
@@ -11,20 +11,22 @@ from .base import OfaBasePreprocessor
 
 class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
 
-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data
 
         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
         super(OfaImageClassificationPreprocessor,
-              self).__init__(cfg, model_dir)
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])
diff --git a/modelscope/preprocessors/ofa/summarization.py b/modelscope/preprocessors/ofa/summarization.py
index fd5113cd..00ae9bf9 100644
--- a/modelscope/preprocessors/ofa/summarization.py
+++ b/modelscope/preprocessors/ofa/summarization.py
@@ -6,14 +6,16 @@ from .base import OfaBasePreprocessor
 
 class OfaSummarizationPreprocessor(OfaBasePreprocessor):
 
-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data
 
         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaSummarizationPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaSummarizationPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         source = super().pre_caption(
diff --git a/modelscope/preprocessors/ofa/text_classification.py b/modelscope/preprocessors/ofa/text_classification.py
index 1a3f84fd..25981e65 100644
--- a/modelscope/preprocessors/ofa/text_classification.py
+++ b/modelscope/preprocessors/ofa/text_classification.py
@@ -6,14 +6,16 @@ from .base import OfaBasePreprocessor
 
 class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
 
-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data
 
         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaTextClassificationPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaTextClassificationPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         text1 = ' '.join(
@@ -34,5 +36,6 @@ class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
         sample = {
             'source': inputs,
             'decoder_prompt': decoder_prompt,
+            'prefix_token': decoder_prompt[:-1],
         }
         return sample
diff --git a/modelscope/preprocessors/ofa/text_to_image_synthesis.py b/modelscope/preprocessors/ofa/text_to_image_synthesis.py
index 9dbba921..56198e67 100644
--- a/modelscope/preprocessors/ofa/text_to_image_synthesis.py
+++ b/modelscope/preprocessors/ofa/text_to_image_synthesis.py
@@ -8,14 +8,16 @@ from .base import OfaBasePreprocessor
 
 class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
 
-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data
 
         Args:
-            model_dir (str): model path
+            cfg(modelscope.utils.config.ConfigDict) : model config
+            model_dir (str): model path,
+            split: data phase
         """
         super(OfaTextToImageSynthesisPreprocessor,
-              self).__init__(cfg, model_dir)
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         self.max_src_length = 64
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/modelscope/preprocessors/ofa/utils/collate.py b/modelscope/preprocessors/ofa/utils/collate.py
index a473335b..82258e8b 100644
--- a/modelscope/preprocessors/ofa/utils/collate.py
+++ b/modelscope/preprocessors/ofa/utils/collate.py
@@ -50,8 +50,11 @@ def collate_fn(samples, pad_idx, eos_idx):
     if samples[0].get('constraint_mask', None) is not None:
         batch['constraint_masks'] = merge('constraint_mask')
     if samples[0].get('decoder_prompt', None) is not None:
-        batch['decoder_prompts'] = np.array(
-            [s['decoder_prompt'].tolist() for s in samples])
+        batch['decoder_prompts'] = torch.stack(
+            [s['decoder_prompt'] for s in samples], dim=0)
+    if samples[0].get('prefix_token', None) is not None:
+        batch['prefix_tokens'] = torch.stack(
+            [s['prefix_token'] for s in samples], dim=0)
     # For detection and visual grounding
     if samples[0].get('w_resize_ratio', None) is not None:
         batch['w_resize_ratios'] = torch.stack(
diff --git a/modelscope/preprocessors/ofa/visual_entailment.py b/modelscope/preprocessors/ofa/visual_entailment.py
index 72e88d75..45c719b1 100644
--- a/modelscope/preprocessors/ofa/visual_entailment.py
+++ b/modelscope/preprocessors/ofa/visual_entailment.py
@@ -11,19 +11,22 @@ from .base import OfaBasePreprocessor
 
 class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
 
-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data
 
         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaVisualEntailmentPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaVisualEntailmentPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])
diff --git a/modelscope/preprocessors/ofa/visual_grounding.py b/modelscope/preprocessors/ofa/visual_grounding.py
index eebc4cf2..eaaed0ef 100644
--- a/modelscope/preprocessors/ofa/visual_grounding.py
+++ b/modelscope/preprocessors/ofa/visual_grounding.py
@@ -11,19 +11,22 @@ from .base import OfaBasePreprocessor
 
 class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
 
-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data
 
         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
-        super(OfaVisualGroundingPreprocessor, self).__init__(cfg, model_dir)
+        super(OfaVisualGroundingPreprocessor,
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])
diff --git a/modelscope/preprocessors/ofa/visual_question_answering.py b/modelscope/preprocessors/ofa/visual_question_answering.py
index b11af9f6..bce18c95 100644
--- a/modelscope/preprocessors/ofa/visual_question_answering.py
+++ b/modelscope/preprocessors/ofa/visual_question_answering.py
@@ -11,20 +11,22 @@ from .base import OfaBasePreprocessor
 
 class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
 
-    def __init__(self, cfg, model_dir):
-        """preprocess the data via the vocab.txt from the `model_dir` path
+    def __init__(self, cfg, model_dir, split, *args, **kwargs):
+        """preprocess the data
 
         Args:
             cfg(modelscope.utils.config.ConfigDict) : model config
-            model_dir (str): model path
+            model_dir (str): model path,
+            split: data phase
         """
         super(OfaVisualQuestionAnsweringPreprocessor,
-              self).__init__(cfg, model_dir)
+              self).__init__(cfg, model_dir, split, *args, **kwargs)
         # Initialize transform
         self.patch_resize_transform = transforms.Compose([
             lambda image: image.convert('RGB'),
-            transforms.Resize((self.patch_image_size, self.patch_image_size),
-                              interpolation=Image.BICUBIC),
+            transforms.Resize(
+                (self.patch_image_size, self.patch_image_size),
+                interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize(mean=self.mean, std=self.std),
         ])
diff --git a/modelscope/trainers/multi_modal/ofa/__init__.py b/modelscope/trainers/multi_modal/ofa/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py b/modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py
new file mode 100644
index 00000000..2f64f9ff
--- /dev/null
+++ b/modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py
@@ -0,0 +1,131 @@
+# Copyright 2022 The OFA-Sys Team.
+# All rights reserved.
+# This source code is licensed under the Apache 2.0 license
+# found in the LICENSE file in the root directory.
+
+import os
+import pickle
+
+import torch
+
+
+class OFAFileDataset:
+
+    def __init__(self,
+                 file_path,
+                 selected_col_ids=None,
+                 dtypes=None,
+                 separator='\t',
+                 cached_index=False):
+        self.file_path = file_path
+        assert os.path.exists(
+            self.file_path), 'Error: The local datafile {} not exists!'.format(
+                self.file_path)
+
+        self.separator = separator
+        if selected_col_ids is None:
+            # default to all fields
+            self.selected_col_ids = list(
+                range(
+                    len(
+                        open(self.file_path).readline().rstrip('\n').split(
+                            self.separator))))
+        else:
+            self.selected_col_ids = [
+                int(col_id) for col_id in selected_col_ids.split(',')
+            ]
+        if dtypes is None:
+            # default to str
+            self.dtypes = [str for col_id in self.selected_col_ids]
+        else:
+            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(',')]
+        assert len(self.dtypes) == len(self.selected_col_ids)
+
+        self.data_cnt = 0
+        try:
+            self.slice_id = torch.distributed.get_rank()
+            self.slice_count = torch.distributed.get_world_size()
+        except Exception:
+            self.slice_id = 0
+            self.slice_count = 1
+        self.cached_index = cached_index
+        self._init_seek_index()
+        self._reader = self._get_reader()
+        print('file {} slice_id {} row count {} total row count {}'.format(
+            self.file_path, self.slice_id, self.row_count,
+            self.total_row_count))
+
+    def _init_seek_index(self):
+        if self.cached_index:
+            cache_path = '{}.index'.format(self.file_path)
+            assert os.path.exists(
+                cache_path), 'cache file {} not exists!'.format(cache_path)
+            self.total_row_count, self.lineid_to_offset = pickle.load(
+                open(cache_path, 'rb'))
+            print(
+                'local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping'
+                .format(self.file_path, self.slice_id))
+        else:
+            # make an iteration over the file to get row_count and line_idx-to-offset mapping
+            fp = open(self.file_path, 'r')
+            print(
+                'local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping'
+                .format(self.file_path, self.slice_id))
+            self.total_row_count = 0
+            offset = 0
+            self.lineid_to_offset = []
+            for line in fp:
+                self.lineid_to_offset.append(offset)
+                self.total_row_count += 1
+                offset += len(line.encode('utf-8'))
+        self._compute_start_pos_and_row_count()
+        print(
+            'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'
+            .format(self.file_path, self.slice_id))
+
+    def _compute_start_pos_and_row_count(self):
+        self.row_count = self.total_row_count // self.slice_count
+        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
+            self.row_count += 1
+            self.start_pos = self.row_count * self.slice_id
+        else:
+            self.start_pos = self.row_count * self.slice_id + (
+                self.total_row_count - self.row_count * self.slice_count)
+
+    def _get_reader(self):
+        fp = open(self.file_path, 'r')
+        fp.seek(self.lineid_to_offset[self.start_pos])
+        return fp
+
+    def _seek(self, offset=0):
+        try:
+            print('slice_id {} seek offset {}'.format(self.slice_id,
+                                                      self.start_pos + offset))
+            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
+            self.data_cnt = offset
+        except Exception:
+            print('slice_id {} seek offset {}'.format(self.slice_id, offset))
+            self._reader.seek(self.lineid_to_offset[offset])
+            self.data_cnt = offset
+
+    def __del__(self):
+        self._reader.close()
+
+    def __len__(self):
+        return self.row_count
+
+    def get_total_row_count(self):
+        return self.total_row_count
+
+    def __getitem__(self, index):
+        if self.data_cnt == self.row_count:
+            print('reach the end of datafile, start a new reader')
+            self.data_cnt = 0
+            self._reader = self._get_reader()
+        column_l = self._reader.readline().rstrip('\n').split(self.separator)
+        self.data_cnt += 1
+        column_l = [
+            dtype(column_l[col_id])
+            for col_id, dtype in zip(self.selected_col_ids, self.dtypes)
+        ]
+        return column_l
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
new file mode 100644
index 00000000..92a22bb4
--- /dev/null
+++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
@@ -0,0 +1,52 @@
+# Copyright 2022 The OFA-Sys Team.
+# All rights reserved.
+# This source code is licensed under the Apache 2.0 license
+# found in the LICENSE file in the root directory.
+from os import path as osp
+
+from torch.utils.data import Dataset
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.preprocessors.multi_modal import OfaPreprocessor
+from modelscope.utils.config import Config
+from modelscope.utils.constant import Fields, ModeKeys, ModelFile, Tasks
+from .ofa_file_dataset import OFAFileDataset
+
+
+class OFADataset(Dataset):
+
+    def __init__(self,
+                 model_dir,
+                 file_path,
+                 dtypes=None,
+                 separator='\t',
+                 cached_index=False,
+                 split=ModeKeys.TRAIN,
+                 **kwargs):
+        self.cfg = Config.from_file(
+            osp.join(model_dir, ModelFile.CONFIGURATION))
+        selected_col_ids = self.cfg.dataset.selected_col_ids
+        selected_col_keys = self.cfg.dataset.selected_col_keys
+
+        assert selected_col_ids is not None
+        assert selected_col_keys is not None
+        self.selected_col_key_l = selected_col_keys.split(',')
+        assert len(self.selected_col_key_l) == len(selected_col_ids.split(','))
+
+        self.dataset = OFAFileDataset(
+            file_path=file_path,
+            selected_col_ids=selected_col_ids,
+            dtypes=dtypes,
+            separator=separator,
+            cached_index=cached_index)
+        self.preprocessor = OfaPreprocessor(model_dir, split)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        value_l = self.dataset[index]
+        data = dict()
+        for key, value in zip(self.selected_col_key_l, value_l):
+            data[key] = value
+        return self.preprocessor(data)
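
Usage sketch (illustrative only, not part of the patch): one way the new OFADataset added in ofa_trainer_utils.py might be driven, assuming a local model directory whose configuration.json defines dataset.selected_col_ids and dataset.selected_col_keys, and a tab-separated data file whose columns follow that mapping; both paths below are hypothetical placeholders.

from modelscope.trainers.multi_modal.ofa.ofa_trainer_utils import OFADataset

# Hypothetical paths; configuration.json in model_dir must carry
# dataset.selected_col_ids / dataset.selected_col_keys so the TSV columns
# can be mapped onto the preprocessor's input keys.
train_set = OFADataset(
    model_dir='/path/to/ofa_model_dir',
    file_path='/path/to/train.tsv',
    separator='\t',
    cached_index=False)

print(len(train_set))  # rows assigned to this rank's slice of the file
# Each item has already passed through OfaPreprocessor, which collates the
# single sample by default (no_collate is not set), so indexing yields a
# padded batch of size 1 produced by collate_fn.
first_batch = train_set[0]
print(list(first_batch.keys()))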