Browse Source

caption finetune done, add bleu

master
行嗔 3 years ago
parent
commit
dbf022efe8
9 changed files with 58 additions and 8 deletions
  1. +3
    -0
      modelscope/metainfo.py
  2. +4
    -0
      modelscope/metrics/__init__.py
  3. +1
    -1
      modelscope/metrics/accuracy_metric.py
  4. +42
    -0
      modelscope/metrics/bleu_metric.py
  5. +0
    -2
      modelscope/models/multi_modal/ofa_for_all_tasks.py
  6. +2
    -0
      modelscope/preprocessors/ofa/image_captioning.py
  7. +0
    -1
      modelscope/trainers/hooks/optimizer/torch_optimizer_hook.py
  8. +1
    -0
      requirements/multi-modal.txt
  9. +5
    -4
      tests/trainers/test_ofa_trainer.py

+ 3
- 0
modelscope/metainfo.py View File

@@ -334,6 +334,9 @@ class Metrics(object):
accuracy = 'accuracy' accuracy = 'accuracy'
audio_noise_metric = 'audio-noise-metric' audio_noise_metric = 'audio-noise-metric'


# text gen
bleu = 'bleu'

# metrics for image denoise task # metrics for image denoise task
image_denoise_metric = 'image-denoise-metric' image_denoise_metric = 'image-denoise-metric'




+ 4
- 0
modelscope/metrics/__init__.py View File

@@ -17,6 +17,8 @@ if TYPE_CHECKING:
from .token_classification_metric import TokenClassificationMetric from .token_classification_metric import TokenClassificationMetric
from .video_summarization_metric import VideoSummarizationMetric from .video_summarization_metric import VideoSummarizationMetric
from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric
from .accuracy_metric import AccuracyMetric
from .bleu_metric import BleuMetric


else: else:
_import_structure = { _import_structure = {
@@ -34,6 +36,8 @@ else:
'token_classification_metric': ['TokenClassificationMetric'], 'token_classification_metric': ['TokenClassificationMetric'],
'video_summarization_metric': ['VideoSummarizationMetric'], 'video_summarization_metric': ['VideoSummarizationMetric'],
'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'],
'accuracy_metric': ['AccuracyMetric'],
'bleu_metric': ['BleuMetric'],
} }


import sys import sys


+ 1
- 1
modelscope/metrics/accuracy_metric.py View File

@@ -11,7 +11,7 @@ from .builder import METRICS, MetricKeys


@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) @METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy)
class AccuracyMetric(Metric): class AccuracyMetric(Metric):
"""The metric computation class for sequence classification classes.
"""The metric computation class for classification classes.


This metric class calculates accuracy for the whole input batches. This metric class calculates accuracy for the whole input batches.
""" """


+ 42
- 0
modelscope/metrics/bleu_metric.py View File

@@ -0,0 +1,42 @@
from itertools import zip_longest
from typing import Dict

import sacrebleu

from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys

EVAL_BLEU_ORDER = 4


@METRICS.register_module(group_key=default_group, module_name=Metrics.bleu)
class BleuMetric(Metric):
    """The metric computation class for BLEU in text generation tasks.

    Hypotheses and references are accumulated batch by batch via ``add``
    and scored once over the whole corpus in ``evaluate`` using sacrebleu.

    Keyword Args:
        eval_tokenized_bleu (bool): If True, the inputs are already
            tokenized, so sacrebleu's internal tokenization is disabled
            (``tokenize='none'``). Defaults to False.
        hyp_name (str): Key of the hypotheses in the model outputs dict.
            Defaults to 'hyp'.
        ref_name (str): Key of the references in the inputs dict.
            Defaults to 'ref'.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False)
        self.hyp_name = kwargs.get('hyp_name', 'hyp')
        self.ref_name = kwargs.get('ref_name', 'ref')
        # Accumulated across ``add`` calls: one hypothesis string per sample,
        # and presumably one list of reference strings per sample so that
        # zip_longest(*refs) transposes them into per-reference streams as
        # sacrebleu expects — TODO confirm against callers.
        self.refs = []
        self.hyps = []

    def add(self, outputs: Dict, inputs: Dict):
        """Accumulate one batch of references and hypotheses."""
        self.refs.extend(inputs[self.ref_name])
        self.hyps.extend(outputs[self.hyp_name])

    def evaluate(self):
        """Compute corpus-level BLEU over everything added so far.

        Returns:
            Dict mapping ``MetricKeys.BLEU_4`` to the sacrebleu score.
        """
        # Transpose [per-sample refs] -> [per-position reference streams];
        # zip_longest pads with None when samples have unequal ref counts.
        transposed_refs = list(zip_longest(*self.refs))
        if self.eval_tokenized_bleu:
            # Inputs are pre-tokenized: turn off sacrebleu's tokenizer.
            bleu = sacrebleu.corpus_bleu(
                self.hyps, transposed_refs, tokenize='none')
        else:
            bleu = sacrebleu.corpus_bleu(self.hyps, transposed_refs)
        return {
            MetricKeys.BLEU_4: bleu.score,
        }

+ 0
- 2
modelscope/models/multi_modal/ofa_for_all_tasks.py View File

@@ -183,8 +183,6 @@ class OfaForAllTasks(TorchModel):
encoder_input[key] = input['net_input'][key] encoder_input[key] = input['net_input'][key]
encoder_out = self.model.encoder(**encoder_input) encoder_out = self.model.encoder(**encoder_input)
valid_result = [] valid_result = []
import pdb
pdb.set_trace()
for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l): for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l):
valid_size = len(val_ans) valid_size = len(val_ans)
valid_tgt_items = [ valid_tgt_items = [


+ 2
- 0
modelscope/preprocessors/ofa/image_captioning.py View File

@@ -66,4 +66,6 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
'patch_image': patch_image, 'patch_image': patch_image,
'patch_mask': torch.tensor([True]) 'patch_mask': torch.tensor([True])
} }
if 'text' in data:
sample['label'] = data['text']
return sample return sample

+ 0
- 1
modelscope/trainers/hooks/optimizer/torch_optimizer_hook.py View File

@@ -79,6 +79,5 @@ class TorchAMPOptimizerHook(OptimizerHook):
self.scaler.step(trainer.optimizer) self.scaler.step(trainer.optimizer)
self.scaler.update(self._scale_update_param) self.scaler.update(self._scale_update_param)
trainer.optimizer.zero_grad() trainer.optimizer.zero_grad()
print('xcxcxcxcxc: optimizer step')


setattr(self._model, 'forward', self._ori_model_forward) setattr(self._model, 'forward', self._ori_model_forward)

+ 1
- 0
requirements/multi-modal.txt View File

@@ -5,6 +5,7 @@ pycocotools>=2.0.4
# rough-score was just recently updated from 0.0.4 to 0.0.7 # rough-score was just recently updated from 0.0.4 to 0.0.7
# which introduced compatability issues that are being investigated # which introduced compatability issues that are being investigated
rouge_score<=0.0.4 rouge_score<=0.0.4
sacrebleu
taming-transformers-rom1504 taming-transformers-rom1504
timm timm
tokenizers tokenizers


+ 5
- 4
tests/trainers/test_ofa_trainer.py View File

@@ -9,13 +9,14 @@ from modelscope.utils.test_utils import test_level


class TestOfaTrainer(unittest.TestCase): class TestOfaTrainer(unittest.TestCase):


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer(self): def test_trainer(self):
model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/maas_mnli_pretrain_ckpt'
self.trainer = OFATrainer(model_id, launcher='pytorch')
model_id = 'damo/ofa_image-caption_coco_huge_en'
self.trainer = OFATrainer(model_id)
os.makedirs(self.trainer.work_dir, exist_ok=True)
self.trainer.train() self.trainer.train()
if os.path.exists(self.trainer.work_dir): if os.path.exists(self.trainer.work_dir):
pass
shutil.rmtree(self.trainer.work_dir)




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save