@@ -129,8 +129,7 @@ class OfaForAllTasks(TorchModel):
         result_l = list()
         for cap in caption:
             result_l.append(cap.translate(self.transtab).strip())
-        input[OutputKeys.CAPTION] = caption
+        input[OutputKeys.CAPTION] = result_l
         return input

     def _text_gen_inference(self, input):
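The fix above matters because the cleaned strings are collected in result_l while the original caption list is left untouched; assigning caption back would silently discard the post-processing. A minimal standalone sketch of the same idea, assuming transtab is a translation table that drops punctuation (the actual table built by OfaForAllTasks may differ):

import string

# Assumption: a table mapping every punctuation character to None,
# standing in for whatever self.transtab actually contains.
transtab = str.maketrans({ch: None for ch in string.punctuation})

captions = ['a dog runs on the beach.', ' a red bus!! ']
result_l = [cap.translate(transtab).strip() for cap in captions]
print(result_l)  # ['a dog runs on the beach', 'a red bus']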
@@ -182,6 +181,8 @@ class OfaForAllTasks(TorchModel):
             encoder_input[key] = input['net_input'][key]
         encoder_out = self.model.encoder(**encoder_input)
         valid_result = []
+        import pdb
+        pdb.set_trace()
         for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l):
             valid_size = len(val_ans)
             valid_tgt_items = [
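The surrounding method encodes the input once and then walks chunks of candidate answers (val_ans_l / val_masks_l), scoring each chunk against the shared encoder output. A rough sketch of that pattern, with hypothetical encode and score_chunk callables standing in for the OFA encoder and decoder calls:

def pick_best_answer(encode, score_chunk, net_input, candidate_chunks):
    # Encode once, then reuse the encoder output for every chunk of candidates.
    encoder_out = encode(**net_input)
    scores, answers = [], []
    for chunk in candidate_chunks:        # plays the role of self.val_ans_l
        scores.extend(score_chunk(encoder_out, chunk))
        answers.extend(chunk)
    best = max(range(len(scores)), key=scores.__getitem__)
    return answers[best]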
@@ -1,133 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-import os
-import pickle
-
-import torch
-
-
-class OFAFileDataset:
-
-    def __init__(self,
-                 file_path,
-                 selected_col_ids=None,
-                 dtypes=None,
-                 separator='\t',
-                 cached_index=False):
-        self.file_path = file_path
-        assert os.path.exists(
-            self.file_path), 'Error: The local datafile {} not exists!'.format(
-                self.file_path)
-
-        self.separator = separator
-        if selected_col_ids is None:
-            # default to all fields
-            self.selected_col_ids = list(
-                range(
-                    len(
-                        open(self.file_path).readline().rstrip('\n').split(
-                            self.separator))))
-        else:
-            self.selected_col_ids = [
-                int(col_id) for col_id in selected_col_ids.split(',')
-            ]
-        if dtypes is None:
-            # default to str
-            self.dtypes = [str for col_id in self.selected_col_ids]
-        else:
-            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(',')]
-        assert len(self.dtypes) == len(self.selected_col_ids)
-
-        self.data_cnt = 0
-        try:
-            self.slice_id = torch.distributed.get_rank()
-            self.slice_count = torch.distributed.get_world_size()
-        except Exception:
-            self.slice_id = 0
-            self.slice_count = 1
-        self.cached_index = cached_index
-        self._init_seek_index()
-        self._reader = self._get_reader()
-        print('file {} slice_id {} row count {} total row count {}'.format(
-            self.file_path, self.slice_id, self.row_count,
-            self.total_row_count))
-
-    def _init_seek_index(self):
-        if self.cached_index:
-            cache_path = '{}.index'.format(self.file_path)
-            assert os.path.exists(
-                cache_path), 'cache file {} not exists!'.format(cache_path)
-            self.total_row_count, self.lineid_to_offset = pickle.load(
-                open(cache_path, 'rb'))
-            print(
-                'local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping'
-                .format(self.file_path, self.slice_id))
-        else:
-            # make an iteration over the file to get row_count and line_idx-to-offset mapping
-            fp = open(self.file_path, 'r')
-            print(
-                'local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping'
-                .format(self.file_path, self.slice_id))
-            self.total_row_count = 0
-            offset = 0
-            self.lineid_to_offset = []
-            for line in fp:
-                self.lineid_to_offset.append(offset)
-                self.total_row_count += 1
-                offset += len(line.encode('utf-8'))
-            pickle.dump(self.lineid_to_offset,
-                        open('{}.index'.format(self.file_path), 'wb'))
-        self._compute_start_pos_and_row_count()
-        print(
-            'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'
-            .format(self.file_path, self.slice_id))
-
-    def _compute_start_pos_and_row_count(self):
-        self.row_count = self.total_row_count // self.slice_count
-        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
-            self.row_count += 1
-            self.start_pos = self.row_count * self.slice_id
-        else:
-            self.start_pos = self.row_count * self.slice_id + (
-                self.total_row_count - self.row_count * self.slice_count)
-
-    def _get_reader(self):
-        fp = open(self.file_path, 'r')
-        fp.seek(self.lineid_to_offset[self.start_pos])
-        return fp
-
-    def _seek(self, offset=0):
-        try:
-            print('slice_id {} seek offset {}'.format(self.slice_id,
-                                                      self.start_pos + offset))
-            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
-            self.data_cnt = offset
-        except Exception:
-            print('slice_id {} seek offset {}'.format(self.slice_id, offset))
-            self._reader.seek(self.lineid_to_offset[offset])
-            self.data_cnt = offset
-
-    def __del__(self):
-        self._reader.close()
-
-    def __len__(self):
-        return self.row_count
-
-    def get_total_row_count(self):
-        return self.total_row_count
-
-    def __getitem__(self, index):
-        if self.data_cnt == self.row_count:
-            print('reach the end of datafile, start a new reader')
-            self.data_cnt = 0
-            self._reader = self._get_reader()
-        column_l = self._reader.readline().rstrip('\n').split(self.separator)
-        self.data_cnt += 1
-        column_l = [
-            dtype(column_l[col_id])
-            for col_id, dtype in zip(self.selected_col_ids, self.dtypes)
-        ]
-        return column_l
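The deleted class is a seek-indexed reader for delimiter-separated text files that shards rows across distributed ranks via a per-line offset table. A hypothetical usage sketch (the file name and column choices are illustrative, not taken from the repository):

# Assumed layout: a TSV with one "caption\timage_path" pair per row.
dataset = OFAFileDataset(
    file_path='train_captions.tsv',  # hypothetical local file
    selected_col_ids='0,1',          # keep the first two columns
    dtypes='str,str',                # each kept column parsed with str()
    separator='\t',
    cached_index=False)              # build the line-offset index on the fly

print(len(dataset))                  # number of rows assigned to this rank
caption, image_path = dataset[0]     # note: reads sequentially; the index argument is ignored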
@@ -65,7 +65,7 @@ class OFATrainer(EpochBasedTrainer):
             kwargs['launcher'] = cfg.train.launcher
         if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
             kwargs['use_fp16'] = cfg.train.use_fp16
+        kwargs['to_tensor'] = False
         super().__init__(
             cfg_file=cfg_file,
             model=model,
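OFA's preprocessors already return fully prepared samples, so the trainer sets to_tensor to False and lets its own collator assemble batches instead of having the task dataset convert every field. A toy illustration of what such a flag changes, assuming a sample dict with mixed field types (this is not the MsDataset implementation):

import torch

def prepare_sample(sample: dict, to_tensor: bool = True) -> dict:
    # With to_tensor=True numeric fields become tensors; with to_tensor=False the
    # dict passes through untouched and collation is left to OFA's own collator.
    if not to_tensor:
        return sample
    return {k: torch.tensor(v) if isinstance(v, (int, float, list)) else v
            for k, v in sample.items()}

sample = {'label': 1, 'text': 'a red bus'}
print(prepare_sample(sample, to_tensor=False))  # {'label': 1, 'text': 'a red bus'}
print(prepare_sample(sample, to_tensor=True))   # {'label': tensor(1), 'text': 'a red bus'}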
@@ -167,19 +167,20 @@ class EpochBasedTrainer(BaseTrainer):
             device_name = f'cuda:{local_rank}'
         self.device = create_device(device_name)
         self.train_dataset = self.to_task_dataset(
             train_dataset,
             mode=ModeKeys.TRAIN,
             task_data_config=self.cfg.dataset.get('train', None) if hasattr(
                 self.cfg, 'dataset') else None,
-            preprocessor=self.train_preprocessor)
+            preprocessor=self.train_preprocessor,
+            **kwargs)
         self.eval_dataset = self.to_task_dataset(
             eval_dataset,
             mode=ModeKeys.EVAL,
             task_data_config=self.cfg.dataset.get('val', None) if hasattr(
                 self.cfg, 'dataset') else None,
-            preprocessor=self.eval_preprocessor)
+            preprocessor=self.eval_preprocessor,
+            **kwargs)

         self.train_data_collator, self.eval_default_collate = None, None
         if isinstance(data_collator, Mapping):
@@ -305,13 +306,15 @@ class EpochBasedTrainer(BaseTrainer):
                         datasets: Union[Dataset, List[Dataset]],
                         mode: str,
                         task_data_config: Config = None,
-                        preprocessor: Optional[Preprocessor] = None):
+                        preprocessor: Optional[Preprocessor] = None,
+                        **kwargs):
         """Build the task specific dataset processor for this trainer.

         Returns: The task dataset processor for the task. If no result for the very model-type and task,
         the default TaskDataset will be returned.
         """
         try:
+            to_tensor = kwargs.get('to_tensor', True)
             if not datasets:
                 return datasets
             if isinstance(datasets, TorchTaskDataset):
@@ -327,7 +330,8 @@ class EpochBasedTrainer(BaseTrainer):
                 return datasets.to_torch_dataset(
                     task_data_config=task_data_config,
                     task_name=self.cfg.task,
-                    preprocessors=preprocessor)
+                    preprocessors=preprocessor,
+                    to_tensor=to_tensor)
             elif isinstance(datasets, List) and isinstance(
                     datasets[0], MsDataset):
                 if task_data_config is None:
@@ -341,7 +345,8 @@ class EpochBasedTrainer(BaseTrainer):
                     d.to_torch_dataset(
                         task_data_config=task_data_config,
                         task_name=self.cfg.task,
-                        preprocessors=preprocessor) for d in datasets
+                        preprocessors=preprocessor,
+                        to_tensor=to_tensor) for d in datasets
                 ]
                 cfg = ConfigDict(
                     type=self.cfg.task, mode=mode, datasets=datasets)
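Taken together, the changes above thread any keyword argument given to the trainer down to to_task_dataset and on to to_torch_dataset. A condensed sketch of that call chain with simplified signatures (SketchTrainer and its single-argument to_torch_dataset call are illustrative only):

class SketchTrainer:
    def __init__(self, dataset, **kwargs):
        # kwargs travel exactly as in the modified EpochBasedTrainer.__init__
        self.train_dataset = self.to_task_dataset(dataset, **kwargs)

    def to_task_dataset(self, dataset, **kwargs):
        to_tensor = kwargs.get('to_tensor', True)  # default preserves the old behaviour
        return dataset.to_torch_dataset(to_tensor=to_tensor)

# SketchTrainer(ms_dataset, to_tensor=False) would hand the flag on to to_torch_dataset.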
@@ -94,8 +94,11 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_text_classification_with_model(self):
+        # model = Model.from_pretrained(
+        #     'damo/ofa_text-classification_mnli_large_en')
         model = Model.from_pretrained(
-            'damo/ofa_text-classification_mnli_large_en')
+            '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
+        )
         ofa_pipe = pipeline(Tasks.text_classification, model=model)
         text = 'One of our number will carry out your instructions minutely.'
         text2 = 'A member of my team will execute your orders with immense precision.'
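For reference, the same check can be run against the hosted checkpoint instead of the hard-coded local path. A sketch assuming the public model id shown in the commented-out lines, and assuming the pipeline accepts the premise/hypothesis pair as a tuple (the exact input format is defined by the OFA preprocessor):

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

model = Model.from_pretrained('damo/ofa_text-classification_mnli_large_en')
ofa_pipe = pipeline(Tasks.text_classification, model=model)
result = ofa_pipe(('One of our number will carry out your instructions minutely.',
                   'A member of my team will execute your orders with immense precision.'))
print(result)  # expected to include the predicted MNLI label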
@@ -12,11 +12,10 @@ class TestOfaTrainer(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer(self):
         model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/maas_mnli_pretrain_ckpt'
+        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
-        self.trainer = OFATrainer(model_id)
+        self.trainer = OFATrainer(model_id, launcher='pytorch')
         self.trainer.train()
         if os.path.exists(self.trainer.work_dir):
-            shutil.rmtree(self.trainer.work_dir)
+            pass


 if __name__ == '__main__':
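Note that launcher='pytorch' makes the trainer expect the usual torch.distributed environment (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), so the test has to be started through torchrun or a similar launcher. A rough guard one could place before constructing the trainer; the environment check is illustrative and the model id is a placeholder:

import os

model_id = 'damo/ofa_text-classification_mnli_large_en'  # or a local checkpoint path
required = ('RANK', 'WORLD_SIZE', 'MASTER_ADDR', 'MASTER_PORT')
if all(var in os.environ for var in required):
    trainer = OFATrainer(model_id, launcher='pytorch')  # distributed run via torchrun
else:
    trainer = OFATrainer(model_id)                      # single-process fallback
trainer.train()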