Add finetune capability for OFA
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10540701
Target branch: master
@@ -27,15 +27,21 @@ class AccuracyMetric(Metric):
         label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
         ground_truths = inputs[label_name]
         eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
         assert type(ground_truths) == type(eval_results)
-        if isinstance(ground_truths, list):
-            self.preds.extend(eval_results)
-            self.labels.extend(ground_truths)
-        elif isinstance(ground_truths, np.ndarray):
-            self.preds.extend(eval_results.tolist())
-            self.labels.extend(ground_truths.tolist())
-        else:
-            raise 'only support list or np.ndarray'
+        for truth in ground_truths:
+            self.labels.append(truth)
+        for result in eval_results:
+            if isinstance(truth, str):
+                self.preds.append(result.strip().replace(' ', ''))
+            else:
+                self.preds.append(result)

     def evaluate(self):
         assert len(self.preds) == len(self.labels)
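Review note: a minimal standalone sketch (plain Python, hypothetical data) of what the reworked `add` now does for string outputs — predictions are whitespace-stripped so space-tokenized OCR hypotheses compare equal to their labels. Also worth flagging: `isinstance(truth, str)` in the second loop reads the last `truth` left over from the first loop, so it silently assumes all ground truths share one type.

    # Sketch of the new AccuracyMetric.add behavior on a toy batch (hypothetical data).
    ground_truths = ['你好', '世界']      # dataset labels
    eval_results = ['你 好', '世 界']     # space-tokenized OCR hypotheses

    labels, preds = [], []
    for truth in ground_truths:
        labels.append(truth)
    for result in eval_results:
        # NOTE: `truth` is whatever the previous loop left bound, i.e. the
        # last ground truth; the type check assumes a homogeneous batch.
        if isinstance(truth, str):
            preds.append(result.strip().replace(' ', ''))  # '你 好' -> '你好'
        else:
            preds.append(result)

    assert preds == labels  # this toy batch would score accuracy 1.0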
@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Dict
+
+import numpy as np
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(group_key=default_group, module_name=Metrics.NED)
+class NedMetric(Metric):
+    """The NED metric computation class for sequence outputs.
+
+    This metric class calculates the normalized Levenshtein (edit) distance
+    between predicted and reference sentences over the whole input batch.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preds = []
+        self.labels = []
+
+    def add(self, outputs: Dict, inputs: Dict):
+        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
+        ground_truths = inputs[label_name]
+        eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
+        assert type(ground_truths) == type(eval_results)
+        if isinstance(ground_truths, list):
+            self.preds.extend(eval_results)
+            self.labels.extend(ground_truths)
+        elif isinstance(ground_truths, np.ndarray):
+            self.preds.extend(eval_results.tolist())
+            self.labels.extend(ground_truths.tolist())
+        else:
+            raise Exception('only support list or np.ndarray')
+
+    def evaluate(self):
+        assert len(self.preds) == len(self.labels)
+        return {
+            MetricKeys.NED: (np.asarray([
+                1.0 - NedMetric._distance(pred, ref)
+                for pred, ref in zip(self.preds, self.labels)
+            ])).mean().item()
+        }
+
+    @staticmethod
+    def _distance(pred, ref):
+        if pred is None or ref is None:
+            raise TypeError('Argument (pred or ref) is NoneType.')
+        if pred == ref:
+            return 0.0
+        if len(pred) == 0:
+            return len(ref)
+        if len(ref) == 0:
+            return len(pred)
+        m_len = max(len(pred), len(ref))
+        if m_len == 0:
+            return 0.0
+
+        def levenshtein(s0, s1):
+            v0 = [0] * (len(s1) + 1)
+            v1 = [0] * (len(s1) + 1)
+            for i in range(len(v0)):
+                v0[i] = i
+            for i in range(len(s0)):
+                v1[0] = i + 1
+                for j in range(len(s1)):
+                    cost = 1
+                    if s0[i] == s1[j]:
+                        cost = 0
+                    v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)
+                v0, v1 = v1, v0
+            return v0[len(s1)]
+
+        return levenshtein(pred, ref) / m_len
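Review note: the early-return branches for empty strings return raw lengths (`len(ref)` / `len(pred)`), which the caller plugs into `1.0 - distance`; for a non-empty counterpart longer than one character that yields a negative score, so a normalized `1.0` may be intended there. For a quick sanity check of the normalization itself, a standalone re-implementation (illustration only, not an import of the class):

    import numpy as np

    def levenshtein(s0, s1):
        # Two-row dynamic-programming edit distance, same scheme as NedMetric._distance.
        v0 = list(range(len(s1) + 1))
        v1 = [0] * (len(s1) + 1)
        for i in range(len(s0)):
            v1[0] = i + 1
            for j in range(len(s1)):
                cost = 0 if s0[i] == s1[j] else 1
                v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)
            v0, v1 = v1, v0
        return v0[len(s1)]

    preds, refs = ['kitten', 'flaw'], ['sitting', 'lawn']
    # NED term per pair: 1 - distance / max(len); 'kitten' vs 'sitting' has distance 3.
    ned = np.mean([1.0 - levenshtein(p, r) / max(len(p), len(r))
                   for p, r in zip(preds, refs)])
    print(round(float(ned), 4))  # ~0.5357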
@@ -91,8 +91,24 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
         ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = data[self.column_map['text']]
+        target = target.translate(self.transtab).strip()
+        target_token_list = target.strip().split()
+        target = ' '.join(target_token_list[:self.max_tgt_length])
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, sample['target'][:-1]])
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
         inputs = self.tokenize_text(prompt)
@@ -102,4 +118,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
             'patch_image': patch_image,
             'patch_mask': torch.tensor([True])
         }
+        if 'text' in self.column_map and self.column_map['text'] in data:
+            sample['label'] = data[self.column_map['text']]
         return sample
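The `prev_output_tokens` built in `_build_train_sample` is the standard teacher-forcing shift: the decoder input is the target shifted right by one position, prefixed with BOS. A minimal sketch with hypothetical token ids:

    import torch

    bos_item = torch.tensor([0])             # hypothetical BOS id
    target = torch.tensor([11, 12, 13, 2])   # tokenized target ending in EOS (hypothetical ids)

    # Decoder input at step t is target[t-1], so the model learns to predict
    # target[t] from the gold prefix.
    prev_output_tokens = torch.cat([bos_item, target[:-1]])
    print(prev_output_tokens.tolist())  # [0, 11, 12, 13]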
@@ -129,9 +129,7 @@ class OFATrainer(EpochBasedTrainer):

     def train_step(self, model, inputs):
         model.train()
-        model_outputs = model.forward(inputs)
-        loss, sample_size, logging_output = self.criterion(
-            model_outputs, inputs)
+        loss, sample_size, logging_output = self.criterion(model, inputs)
         train_outputs = {'loss': loss}
         # add model output info to log
         if 'log_vars' not in train_outputs:
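The forward pass moves out of `train_step` and into the criterion (see the criterion change below), presumably so the loss can adjust the sample, e.g. duplicate it when `use_rdrop` is enabled, before running the model. A hedged illustration of that duplication idea (hypothetical helper, not the actual `construct_rdrop_sample`):

    import torch

    def duplicate_for_rdrop(net_input: dict) -> dict:
        # Hypothetical stand-in: repeat each tensor along the batch dim so two
        # dropout-perturbed copies of every sample share a single forward pass.
        return {k: torch.cat([v, v], dim=0) if torch.is_tensor(v) else v
                for k, v in net_input.items()}

    net_input = {'input_ids': torch.tensor([[5, 6, 7]])}
    print(duplicate_for_rdrop(net_input)['input_ids'].shape)  # torch.Size([2, 3])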
@@ -123,7 +123,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         self.padding_idx = args.tokenizer.pad_token_id
         self.args = args

-    def forward(self, output, sample, update_num=0, reduce=True):
+    def forward(self, model, sample, update_num=0, reduce=True):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
@@ -131,11 +131,16 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
+        if 'labels' in sample:
+            del sample['labels']
+        if 'samples' in sample:
+            del sample['samples']
         if self.use_rdrop:
             construct_rdrop_sample(sample)

+        output = model.model(**sample['net_input'])
         loss, nll_loss, ntokens = self.compute_loss(
-            output, sample, update_num, reduce=reduce)
+            output.logits, sample, update_num, reduce=reduce)
         sample_size = (
             sample['target'].size(0) if self.sentence_avg else ntokens)
         logging_output = {
@@ -147,19 +152,18 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         }
         return loss, sample_size, logging_output

-    def get_lprobs_and_target(self, net_output, sample):
+    def get_lprobs_and_target(self, logits, sample):
         conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
             'conf'] is not None else 1
         constraint_masks = None
         if 'constraint_masks' in sample and sample[
                 'constraint_masks'] is not None:
             constraint_masks = sample['constraint_masks']
-            net_output[0].masked_fill_(~constraint_masks, -math.inf)
+            logits.masked_fill_(~constraint_masks, -math.inf)
         if self.constraint_start is not None and self.constraint_end is not None:
-            net_output[0][:, :, 4:self.constraint_start] = -math.inf
-            net_output[0][:, :, self.constraint_end:] = -math.inf
-        lprobs = F.log_softmax(
-            net_output[0], dim=-1, dtype=torch.float32) * conf
+            logits[:, :, 4:self.constraint_start] = -math.inf
+            logits[:, :, self.constraint_end:] = -math.inf
+        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) * conf
         target = sample['target']
         if self.ignore_prefix_size > 0:
             lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
@@ -180,9 +184,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return lprobs.view(-1,
                            lprobs.size(-1)), target.view(-1), constraint_masks

-    def compute_loss(self, net_output, sample, update_num, reduce=True):
+    def compute_loss(self, logits, sample, update_num, reduce=True):
         lprobs, target, constraint_masks = self.get_lprobs_and_target(
-            net_output, sample)
+            logits, sample)
         if constraint_masks is not None:
             constraint_masks = constraint_masks[target != self.padding_idx]
         lprobs = lprobs[target != self.padding_idx]
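For reviewers unfamiliar with the loss, a self-contained sketch of plain label-smoothed NLL on toy logits (illustrative only; the real `compute_loss` additionally handles constraint masks, padding, prefix ignoring, and R-Drop regularization):

    import torch
    import torch.nn.functional as F

    eps = 0.1                                  # label-smoothing factor
    logits = torch.randn(2, 4, 100)            # (batch, seq, vocab), toy values
    target = torch.randint(0, 100, (2, 4))

    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32).view(-1, 100)
    flat_target = target.view(-1)

    nll_loss = -lprobs.gather(1, flat_target.unsqueeze(1)).squeeze(1)
    smooth_loss = -lprobs.mean(dim=-1)         # spread eps mass uniformly over the vocab
    loss = ((1.0 - eps) * nll_loss + eps * smooth_loss).mean()
    print(loss.item())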
@@ -5,10 +5,10 @@ import unittest

 import json

-from modelscope.metainfo import Metrics, Trainers
+from modelscope.metainfo import Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
-from modelscope.utils.constant import ModelFile
+from modelscope.utils.constant import DownloadMode, ModelFile
 from modelscope.utils.test_utils import test_level
@@ -17,26 +17,27 @@ class TestOfaTrainer(unittest.TestCase):

     def setUp(self) -> None:
         self.finetune_cfg = \
             {'framework': 'pytorch',
-             'task': 'image-captioning',
+             'task': 'ocr-recognition',
              'model': {'type': 'ofa',
                        'beam_search': {'beam_size': 5,
-                                       'max_len_b': 16,
+                                       'max_len_b': 64,
                                        'min_len': 1,
                                        'no_repeat_ngram_size': 0},
                        'seed': 7,
-                       'max_src_length': 256,
-                       'language': 'en',
+                       'max_src_length': 128,
+                       'language': 'zh',
                        'gen_type': 'generation',
                        'patch_image_size': 480,
+                       'is_document': False,
                        'max_image_size': 480,
                        'imagenet_default_mean_and_std': False},
-             'pipeline': {'type': 'image-captioning'},
-             'dataset': {'column_map': {'text': 'caption'}},
-             'train': {'work_dir': 'work/ckpts/caption',
+             'pipeline': {'type': 'ofa-ocr-recognition'},
+             'dataset': {'column_map': {'text': 'label'}},
+             'train': {'work_dir': 'work/ckpts/recognition',
                        # 'launcher': 'pytorch',
                        'max_epochs': 1,
                        'use_fp16': True,
-                       'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
+                       'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
                        'lr_scheduler': {'name': 'polynomial_decay',
                                         'warmup_proportion': 0.01,
                                         'lr_end': 1e-07},
@@ -57,47 +58,48 @@ class TestOfaTrainer(unittest.TestCase):
                                      'report_accuracy': False,
                                      'sample_patch_num': 196,
                                      'sentence_avg': False,
-                                     'use_rdrop': False},
+                                     'use_rdrop': True},
                        'hooks': [{'type': 'BestCkptSaverHook',
-                                  'metric_key': 'bleu-4',
+                                  'metric_key': 'accuracy',
                                   'interval': 100},
                                  {'type': 'TextLoggerHook', 'interval': 1},
                                  {'type': 'IterTimerHook'},
                                  {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
              'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
-                            'metrics': [{'type': 'bleu',
-                                         'eval_tokenized_bleu': False,
-                                         'ref_name': 'labels',
-                                         'hyp_name': 'caption'}]},
+                            'metrics': [{'type': 'accuracy'}]},
              'preprocessor': []}

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_std(self):
-        WORKSPACE = './workspace/ckpts/caption'
+        WORKSPACE = './workspace/ckpts/recognition'
         os.makedirs(WORKSPACE, exist_ok=True)
         config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
         with open(config_file, 'w') as writer:
             json.dump(self.finetune_cfg, writer)

-        pretrained_model = 'damo/ofa_image-caption_coco_distilled_en'
+        pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
         args = dict(
             model=pretrained_model,
             work_dir=WORKSPACE,
             train_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
-                split='train[:20]'),
+                split='train[:200]',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             eval_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
-                split='validation[:10]'),
-            metrics=[Metrics.BLEU],
+                split='test[:20]',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             cfg_file=config_file)
         trainer = build_trainer(name=Trainers.ofa, default_args=args)
         trainer.train()
-        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
-                      os.listdir(os.path.join(WORKSPACE, 'output')))
+        self.assertIn(
+            ModelFile.TORCH_MODEL_BIN_FILE,
+            os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))

         shutil.rmtree(WORKSPACE)
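As a follow-up check (not part of this diff), the saved checkpoint could be loaded through the pipeline type named in the config; a hedged sketch, assuming the test's final `shutil.rmtree` is skipped so the train output directory survives:

    from modelscope.pipelines import pipeline

    # Hypothetical usage; the output path comes from the test above.
    ocr = pipeline('ocr-recognition', model='./workspace/ckpts/recognition/output')
    print(ocr('path/to/some_image.jpg'))  # hypothetical input image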