| @@ -126,7 +126,7 @@ class OfaForAllTasks(TorchModel): | |||||
| return ret | return ret | ||||
| def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]: | ||||
| if self.cfg.task == Tasks.image_captioning: | |||||
| if not self.model.training and self.cfg.task == Tasks.image_captioning: | |||||
| caption = input[OutputKeys.CAPTION] | caption = input[OutputKeys.CAPTION] | ||||
| result_l = list() | result_l = list() | ||||
| for cap in caption: | for cap in caption: | ||||
| @@ -1,5 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import re | import re | ||||
| import string | |||||
| from os import path as osp | from os import path as osp | ||||
| import json | import json | ||||
| @@ -58,6 +59,9 @@ class OfaBasePreprocessor: | |||||
| self.mean = [0.5, 0.5, 0.5] | self.mean = [0.5, 0.5, 0.5] | ||||
| self.std = [0.5, 0.5, 0.5] | self.std = [0.5, 0.5, 0.5] | ||||
| self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | self.patch_image_size = self.cfg.model.get('patch_image_size', 480) | ||||
| self.transtab = str.maketrans( | |||||
| {key: None | |||||
| for key in string.punctuation}) | |||||
| self.constraint_trie = None | self.constraint_trie = None | ||||
| if self.cfg.model.get('answer2label', None): | if self.cfg.model.get('answer2label', None): | ||||
| ans2label_file = osp.join(model_dir, self.cfg.model.answer2label) | ans2label_file = osp.join(model_dir, self.cfg.model.answer2label) | ||||
| @@ -1,4 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
| import torch | import torch | ||||
| @@ -43,6 +44,17 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | |||||
| else: | else: | ||||
| return self._build_infer_sample(data) | return self._build_infer_sample(data) | ||||
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Build a training sample for image captioning.

    Starts from the inference sample produced by
    ``self._build_infer_sample`` and adds the caption supervision.

    Args:
        data: Raw example. ``data['text']`` holds the reference caption;
            the remaining keys are consumed by ``_build_infer_sample``.

    Returns:
        The inference sample extended with two entries:
        ``'target'`` (the tokenized caption, tokenized with
        ``add_bos=False``) and ``'prev_output_tokens'``
        (``self.bos_item`` followed by the target minus its last token).
    """
    sample = self._build_infer_sample(data)

    # Remove punctuation, then keep at most ``max_tgt_length`` words.
    # ``split()`` without arguments already ignores leading/trailing
    # whitespace, so no explicit ``strip()`` is required.
    caption_words = data['text'].translate(self.transtab).split()
    caption = ' '.join(caption_words[:self.max_tgt_length])

    target_ids = self.tokenize_text(caption, add_bos=False)
    sample['target'] = target_ids
    # Decoder input is the target shifted right by one position.
    sample['prev_output_tokens'] = torch.cat(
        [self.bos_item, target_ids[:-1]])
    return sample
| def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | ||||
| image = data['image'] if isinstance( | image = data['image'] if isinstance( | ||||
| data['image'], Image.Image) else load_image(data['image']) | data['image'], Image.Image) else load_image(data['image']) | ||||
| @@ -55,12 +67,3 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | |||||
| 'patch_mask': torch.tensor([True]) | 'patch_mask': torch.tensor([True]) | ||||
| } | } | ||||
| return sample | return sample | ||||
| def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| sample = self._build_infer_sample(data) | |||||
| target = data['target'] | |||||
| target = target.translate(self.transtab).strip() | |||||
| target_token_list = target.strip().split() | |||||
| target = ' '.join(target_token_list[:self.max_tgt_length]) | |||||
| sample['target'] = self.tokenize_text(target) | |||||
| return sample | |||||
| @@ -79,5 +79,6 @@ class TorchAMPOptimizerHook(OptimizerHook): | |||||
| self.scaler.step(trainer.optimizer) | self.scaler.step(trainer.optimizer) | ||||
| self.scaler.update(self._scale_update_param) | self.scaler.update(self._scale_update_param) | ||||
| trainer.optimizer.zero_grad() | trainer.optimizer.zero_grad() | ||||
| print('xcxcxcxcxc: optimizer step') | |||||
| setattr(self._model, 'forward', self._ori_model_forward) | setattr(self._model, 'forward', self._ori_model_forward) | ||||
| @@ -1,6 +1,6 @@ | |||||
| import math | |||||
| import os | import os | ||||
| from functools import partial | from functools import partial | ||||
| from typing import Dict, Optional | |||||
| from datasets import load_dataset | from datasets import load_dataset | ||||
| from torch import distributed as dist | from torch import distributed as dist | ||||
| @@ -27,13 +27,7 @@ class OFATrainer(EpochBasedTrainer): | |||||
| model_dir = model.model_dir | model_dir = model.model_dir | ||||
| cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | ||||
| cfg = Config.from_file(cfg_file) | cfg = Config.from_file(cfg_file) | ||||
| dataset = load_dataset( | |||||
| cfg.dataset.script, | |||||
| data_files=cfg.dataset.hf_dataset, | |||||
| sep=cfg.dataset.sep, | |||||
| ) | |||||
| dataset = MsDataset.from_hf_dataset( | |||||
| dataset.rename_columns(cfg.dataset.column_map)) | |||||
| dataset = self._build_dataset_with_config(cfg) | |||||
| preprocessor = { | preprocessor = { | ||||
| ConfigKeys.train: | ConfigKeys.train: | ||||
| OfaPreprocessor( | OfaPreprocessor( | ||||
| @@ -42,9 +36,11 @@ class OFATrainer(EpochBasedTrainer): | |||||
| OfaPreprocessor( | OfaPreprocessor( | ||||
| model_dir=model_dir, mode=ModeKeys.EVAL, no_collate=True), | model_dir=model_dir, mode=ModeKeys.EVAL, no_collate=True), | ||||
| } | } | ||||
| epoch_steps = len(dataset['train']) // ( | |||||
| cfg.train.optimizer_hook.cumulative_iters | |||||
| * cfg.train.dataloader.batch_size_per_gpu) | |||||
| # use torchrun launch | |||||
| world_size = int(os.environ.get('WORLD_SIZE', 1)) | |||||
| epoch_steps = math.ceil( | |||||
| len(dataset['train']) / # noqa | |||||
| (cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa | |||||
| cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs | cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs | ||||
| cfg.train.criterion.tokenizer = model.tokenizer | cfg.train.criterion.tokenizer = model.tokenizer | ||||
| self.criterion = AdjustLabelSmoothedCrossEntropyCriterion( | self.criterion = AdjustLabelSmoothedCrossEntropyCriterion( | ||||
| @@ -104,3 +100,24 @@ class OFATrainer(EpochBasedTrainer): | |||||
| else: | else: | ||||
| self.log_buffer.update(train_outputs['log_vars']) | self.log_buffer.update(train_outputs['log_vars']) | ||||
| self.train_outputs = train_outputs | self.train_outputs = train_outputs | ||||
def _build_dataset_with_config(self, cfg):
    """Build the training/evaluation dataset(s) described by ``cfg.dataset``.

    Two configuration layouts are supported, checked in this order:

    * ``cfg.dataset.hf_dataset``: the data files are loaded through
      ``load_dataset`` using ``cfg.dataset.script`` and
      ``cfg.dataset.sep``; columns are renamed with
      ``cfg.dataset.column_map`` and the result is wrapped with
      ``MsDataset.from_hf_dataset``.
    * ``cfg.dataset.ms_dataset``: a mapping from a key to the keyword
      arguments of ``MsDataset.load``. Every entry is loaded, its
      columns renamed with ``cfg.dataset.column_map``, and re-wrapped.

    Args:
        cfg: Configuration object exposing a ``dataset`` section.

    Returns:
        The wrapped dataset for the ``hf_dataset`` layout, or a ``dict``
        of wrapped datasets keyed like ``cfg.dataset.ms_dataset`` for the
        ``ms_dataset`` layout.

    Raises:
        NotImplementedError: If ``cfg.dataset`` defines neither
            ``hf_dataset`` nor ``ms_dataset``.
    """
    dataset_cfg = cfg.dataset

    if hasattr(dataset_cfg, 'hf_dataset'):
        hf_dataset = load_dataset(
            dataset_cfg.script,
            data_files=dataset_cfg.hf_dataset,
            sep=dataset_cfg.sep,
        )
        return MsDataset.from_hf_dataset(
            hf_dataset.rename_columns(dataset_cfg.column_map))

    if hasattr(dataset_cfg, 'ms_dataset'):
        datasets = {}
        for key, load_kwargs in dataset_cfg.ms_dataset.items():
            loaded = MsDataset.load(**load_kwargs)
            # NOTE(review): relies on the private ``_hf_ds`` attribute of
            # MsDataset to reach the underlying dataset for renaming.
            datasets[key] = MsDataset.from_hf_dataset(
                loaded._hf_ds.rename_columns(dataset_cfg.column_map))
        return datasets

    raise NotImplementedError(
        "cfg.dataset must define either 'hf_dataset' or 'ms_dataset'")
| @@ -216,7 +216,6 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| self._max_epochs = self.cfg.train.max_epochs | self._max_epochs = self.cfg.train.max_epochs | ||||
| else: | else: | ||||
| self._max_epochs = kwargs['max_epochs'] | self._max_epochs = kwargs['max_epochs'] | ||||
| self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None) | self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None) | ||||
| self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None) | self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None) | ||||
| if self._train_iters_per_epoch is None and hasattr( | if self._train_iters_per_epoch is None and hasattr( | ||||