| @@ -334,6 +334,9 @@ class Metrics(object): | |||||
| accuracy = 'accuracy' | accuracy = 'accuracy' | ||||
| audio_noise_metric = 'audio-noise-metric' | audio_noise_metric = 'audio-noise-metric' | ||||
| # text gen | |||||
| bleu = 'bleu' | |||||
| # metrics for image denoise task | # metrics for image denoise task | ||||
| image_denoise_metric = 'image-denoise-metric' | image_denoise_metric = 'image-denoise-metric' | ||||
| @@ -17,6 +17,8 @@ if TYPE_CHECKING: | |||||
| from .token_classification_metric import TokenClassificationMetric | from .token_classification_metric import TokenClassificationMetric | ||||
| from .video_summarization_metric import VideoSummarizationMetric | from .video_summarization_metric import VideoSummarizationMetric | ||||
| from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric | from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric | ||||
| from .accuracy_metric import AccuracyMetric | |||||
| from .bleu_metric import BleuMetric | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -34,6 +36,8 @@ else: | |||||
| 'token_classification_metric': ['TokenClassificationMetric'], | 'token_classification_metric': ['TokenClassificationMetric'], | ||||
| 'video_summarization_metric': ['VideoSummarizationMetric'], | 'video_summarization_metric': ['VideoSummarizationMetric'], | ||||
| 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], | 'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'], | ||||
| 'accuracy_metric': ['AccuracyMetric'], | |||||
| 'bleu_metric': ['BleuMetric'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -11,7 +11,7 @@ from .builder import METRICS, MetricKeys | |||||
| @METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) | @METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy) | ||||
| class AccuracyMetric(Metric): | class AccuracyMetric(Metric): | ||||
| """The metric computation class for sequence classification classes. | |||||
| """The metric computation class for classification classes. | |||||
| This metric class calculates accuracy for the whole input batches. | This metric class calculates accuracy for the whole input batches. | ||||
| """ | """ | ||||
| @@ -0,0 +1,42 @@ | |||||
| from itertools import zip_longest | |||||
| from typing import Dict | |||||
| import sacrebleu | |||||
| from modelscope.metainfo import Metrics | |||||
| from modelscope.utils.registry import default_group | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
# Maximum n-gram order for BLEU (i.e. BLEU-4).
# NOTE(review): unused in the visible portion of this module — confirm callers.
EVAL_BLEU_ORDER = 4
@METRICS.register_module(group_key=default_group, module_name=Metrics.bleu)
class BleuMetric(Metric):
    """The metric computation class for text generation tasks.

    Accumulates hypotheses and references across batches via ``add`` and
    computes the corpus-level BLEU score with sacrebleu in ``evaluate``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # When True, the inputs are assumed to be pre-tokenized, so
        # sacrebleu's internal tokenizer is disabled (tokenize='none').
        self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False)
        # Keys under which hypotheses/references are found in the
        # outputs/inputs dicts passed to ``add``.
        self.hyp_name = kwargs.get('hyp_name', 'hyp')
        self.ref_name = kwargs.get('ref_name', 'ref')
        self.refs = list()
        self.hyps = list()

    def add(self, outputs: Dict, inputs: Dict):
        """Accumulate one batch of hypotheses and references."""
        self.refs.extend(inputs[self.ref_name])
        self.hyps.extend(outputs[self.hyp_name])

    def evaluate(self):
        """Compute corpus BLEU over everything collected by ``add``.

        Returns:
            A dict mapping ``MetricKeys.BLEU_4`` to the sacrebleu corpus
            BLEU score (a float in [0, 100]).
        """
        # sacrebleu expects references transposed: one stream per reference
        # position, each aligned with the hypothesis list. zip_longest pads
        # with None when samples have unequal reference counts.
        # NOTE(review): confirm unequal reference counts are intended
        # upstream; sacrebleu treats a None reference as empty.
        tokenize_kwargs = {'tokenize': 'none'} if self.eval_tokenized_bleu else {}
        bleu = sacrebleu.corpus_bleu(
            self.hyps, list(zip_longest(*self.refs)), **tokenize_kwargs)
        return {
            MetricKeys.BLEU_4: bleu.score,
        }
| @@ -183,8 +183,6 @@ class OfaForAllTasks(TorchModel): | |||||
| encoder_input[key] = input['net_input'][key] | encoder_input[key] = input['net_input'][key] | ||||
| encoder_out = self.model.encoder(**encoder_input) | encoder_out = self.model.encoder(**encoder_input) | ||||
| valid_result = [] | valid_result = [] | ||||
| import pdb | |||||
| pdb.set_trace() | |||||
| for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l): | for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l): | ||||
| valid_size = len(val_ans) | valid_size = len(val_ans) | ||||
| valid_tgt_items = [ | valid_tgt_items = [ | ||||
| @@ -66,4 +66,6 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor): | |||||
| 'patch_image': patch_image, | 'patch_image': patch_image, | ||||
| 'patch_mask': torch.tensor([True]) | 'patch_mask': torch.tensor([True]) | ||||
| } | } | ||||
| if 'text' in data: | |||||
| sample['label'] = data['text'] | |||||
| return sample | return sample | ||||
| @@ -79,6 +79,5 @@ class TorchAMPOptimizerHook(OptimizerHook): | |||||
| self.scaler.step(trainer.optimizer) | self.scaler.step(trainer.optimizer) | ||||
| self.scaler.update(self._scale_update_param) | self.scaler.update(self._scale_update_param) | ||||
| trainer.optimizer.zero_grad() | trainer.optimizer.zero_grad() | ||||
| print('xcxcxcxcxc: optimizer step') | |||||
| setattr(self._model, 'forward', self._ori_model_forward) | setattr(self._model, 'forward', self._ori_model_forward) | ||||
| @@ -5,6 +5,7 @@ pycocotools>=2.0.4 | |||||
| # rouge-score was just recently updated from 0.0.4 to 0.0.7 | # rouge-score was just recently updated from 0.0.4 to 0.0.7 | ||||
| # which introduced compatibility issues that are being investigated | # which introduced compatibility issues that are being investigated | ||||
| rouge_score<=0.0.4 | rouge_score<=0.0.4 | ||||
| sacrebleu | |||||
| taming-transformers-rom1504 | taming-transformers-rom1504 | ||||
| timm | timm | ||||
| tokenizers | tokenizers | ||||
| @@ -9,13 +9,14 @@ from modelscope.utils.test_utils import test_level | |||||
| class TestOfaTrainer(unittest.TestCase): | class TestOfaTrainer(unittest.TestCase): | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_trainer(self): | def test_trainer(self): | ||||
| model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/maas_mnli_pretrain_ckpt' | |||||
| self.trainer = OFATrainer(model_id, launcher='pytorch') | |||||
| model_id = 'damo/ofa_image-caption_coco_huge_en' | |||||
| self.trainer = OFATrainer(model_id) | |||||
| os.makedirs(self.trainer.work_dir, exist_ok=True) | |||||
| self.trainer.train() | self.trainer.train() | ||||
| if os.path.exists(self.trainer.work_dir): | if os.path.exists(self.trainer.work_dir): | ||||
| pass | |||||
| shutil.rmtree(self.trainer.work_dir) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||