From c3a494e46d3fa80f8743dd7cd4123d79f5cb574a Mon Sep 17 00:00:00 2001
From: "shiyi.zxh"
Date: Tue, 6 Dec 2022 20:58:49 +0800
Subject: [PATCH] [to #42322933] enable finetune of ofa-mmspeech

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10981972
---
 .../models/multi_modal/ofa_for_all_tasks.py   |   2 +
 modelscope/preprocessors/ofa/asr.py           |   5 +-
 modelscope/preprocessors/ofa/base.py          |   3 +
 modelscope/preprocessors/ofa/utils/collate.py |   4 +
 .../trainers/multi_modal/ofa/ofa_trainer.py   |  29 ++---
 .../multi_modal/ofa/ofa_trainer_utils.py      |  29 ++++-
 tests/trainers/test_ofa_mmspeech_trainer.py   | 108 ++++++++++++++++++
 tests/trainers/test_ofa_trainer.py            |   3 +-
 8 files changed, 154 insertions(+), 29 deletions(-)
 create mode 100644 tests/trainers/test_ofa_mmspeech_trainer.py

diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py
index 1ae746b7..3a35be58 100644
--- a/modelscope/models/multi_modal/ofa_for_all_tasks.py
+++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py
@@ -41,6 +41,8 @@ __all__ = ['OfaForAllTasks']
 class OfaForAllTasks(TorchModel):
 
     def __init__(self, model_dir, *args, **kwargs):
+        if os.path.exists(model_dir):
+            model_dir = os.path.abspath(model_dir)
         super().__init__(model_dir=model_dir, *args, **kwargs)
         self.cfg = Config.from_file(
             osp.join(model_dir, ModelFile.CONFIGURATION))
diff --git a/modelscope/preprocessors/ofa/asr.py b/modelscope/preprocessors/ofa/asr.py
index f4ae2097..5d36b829 100644
--- a/modelscope/preprocessors/ofa/asr.py
+++ b/modelscope/preprocessors/ofa/asr.py
@@ -80,10 +80,11 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)
 
-        phone_item = self.to_phone(target) - 3
+        phone_item = self.to_phone(target) + 1
         phone_mask = torch.tensor([False])
 
-        sample['phone_item'] = phone_item
+        sample['phone_item'] = phone_item + 3
+        sample['phone_target'] = phone_item
         sample['phone_mask'] = phone_mask
 
         sample['prev_output_tokens'] = torch.cat(
diff --git a/modelscope/preprocessors/ofa/base.py b/modelscope/preprocessors/ofa/base.py
index 4faa22fe..c2b61c5e 100644
--- a/modelscope/preprocessors/ofa/base.py
+++ b/modelscope/preprocessors/ofa/base.py
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import io
+import os
 import re
 import string
 from os import path as osp
@@ -32,6 +33,8 @@ class OfaBasePreprocessor:
         self.cfg = cfg
         self.mode = mode
         self.language = self.cfg.model.get('language', 'en')
+        if os.path.exists(model_dir):
+            model_dir = os.path.abspath(model_dir)
         if self.language == 'en':
             tokenizer = OFATokenizer.from_pretrained(model_dir)
         elif self.language in ['zh', 'cn']:
diff --git a/modelscope/preprocessors/ofa/utils/collate.py b/modelscope/preprocessors/ofa/utils/collate.py
index 440ea9a0..b5dacd04 100644
--- a/modelscope/preprocessors/ofa/utils/collate.py
+++ b/modelscope/preprocessors/ofa/utils/collate.py
@@ -83,6 +83,10 @@ def collate_fn(samples, pad_idx, eos_idx):
         batch['net_input']['phone_items'] = merge('phone_item')
         batch['net_input']['phone_masks'] = torch.cat(
             [s['phone_mask'] for s in samples])
+        if samples[0].get('phone_target', None) is not None:
+            batch['phone_target'] = merge('phone_target')
+            batch['phone_length'] = torch.tensor(
+                [s['phone_target'].size(0) for s in samples], dtype=torch.long)
 
     return batch
 
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
index 1188fc46..f7801f09 100644
--- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
+++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
@@ -2,8 +2,8 @@
 
 import math
 import os
-import shutil
 from functools import partial
+from shutil import ignore_patterns
 from typing import Callable, Dict, Optional, Tuple, Union
 
 import torch
@@ -23,9 +23,9 @@ from modelscope.trainers.optimizer.builder import build_optimizer
 from modelscope.trainers.parallel.utils import is_parallel
 from modelscope.utils.config import Config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
-                                       Invoke, ModeKeys)
+                                       Invoke, ModeKeys, ModelFile)
 from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
-                                get_schedule)
+                                get_schedule, recursive_overwrite)
 
 
 @TRAINERS.register_module(module_name=Trainers.ofa)
@@ -58,23 +58,12 @@ class OFATrainer(EpochBasedTrainer):
             work_dir = cfg.train.work_dir
         else:
             work_dir = kwargs['work_dir']
-        tokenizer_files = {
-            'zh': [
-                'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
-                'config.json', 'ans2label.json'
-            ],
-            'en': [
-                'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
-                'ans2label.json'
-            ],
-        }
-        for filename in tokenizer_files[cfg.model.get('language', 'en')]:
-            finetune_file = os.path.join(work_dir, filename)
-            pretrain_file = os.path.join(model_dir, filename)
-            if os.path.exists(finetune_file):
-                continue
-            if os.path.exists(pretrain_file):
-                shutil.copy(pretrain_file, finetune_file)
+
+        os.makedirs(work_dir, exist_ok=True)
+        ignore_file_set = set()
+        ignore_file_set.add(ModelFile.CONFIGURATION)
+        recursive_overwrite(
+            model_dir, work_dir, ignore=ignore_patterns(*ignore_file_set))
 
         if preprocessor is None:
             preprocessor = {
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
index c8cf6db5..ffd4cf78 100644
--- a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
+++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the Apache 2.0 license
 # found in the LICENSE file in the root directory.
 import math
+import os
+import shutil
 
 import numpy as np
 import torch
@@ -11,6 +13,23 @@ import transformers
 from torch.nn.modules.loss import _Loss
 
 
+def recursive_overwrite(src, dst, ignore=None):
+    if os.path.isdir(src):
+        if not os.path.isdir(dst):
+            os.makedirs(dst)
+        files = os.listdir(src)
+        if ignore is not None:
+            ignored = ignore(src, files)
+        else:
+            ignored = set()
+        for f in files:
+            if f not in ignored:
+                recursive_overwrite(
+                    os.path.join(src, f), os.path.join(dst, f), ignore)
+    else:
+        shutil.copyfile(src, dst)
+
+
 def construct_rdrop_sample(x):
     if isinstance(x, dict):
         for key in x:
@@ -211,17 +230,17 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return loss, nll_loss, ntokens
 
     def compute_ctc_loss(self, model, output, sample):
-        lprobs = model.get_encoder_normalized_probs(
+        lprobs = model.model.get_encoder_normalized_probs(
             output, log_probs=True).contiguous()  # (T, B, C) from the encoder
 
         non_padding_mask = ~output.encoder_padding_mask
         input_lengths = non_padding_mask.long().sum(-1)
 
-        target_lengths = sample['ctc_output_lengths']
+        target_lengths = sample['phone_length']
         pad_mask = torch.arange(target_lengths.max()).expand([
             target_lengths.shape[0], -1
         ]).to(target_lengths) < target_lengths.unsqueeze(1)
-        targets_flat = sample['ctc_outputs'].masked_select(pad_mask)
+        targets_flat = sample['phone_target'].masked_select(pad_mask)
 
         with torch.backends.cudnn.flags(enabled=False):
             loss = F.ctc_loss(
@@ -229,12 +248,12 @@
                 targets_flat,
                 input_lengths,
                 target_lengths,
-                blank=self.blank_idx,
+                blank=0,
                 reduction='sum',
                 zero_infinity=True,
             )
 
-        return loss
+        return loss / lprobs.shape[1]
 
 
 def get_schedule(scheduler):
diff --git a/tests/trainers/test_ofa_mmspeech_trainer.py b/tests/trainers/test_ofa_mmspeech_trainer.py
new file mode 100644
index 00000000..2c4f6307
--- /dev/null
+++ b/tests/trainers/test_ofa_mmspeech_trainer.py
@@ -0,0 +1,108 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import unittest
+
+import json
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode, ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestMMSpeechTrainer(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.finetune_cfg = \
+            {'framework': 'pytorch',
+             'task': 'auto-speech-recognition',
+             'model': {'type': 'ofa',
+                       'beam_search': {'beam_size': 5,
+                                       'max_len_b': 128,
+                                       'min_len': 1,
+                                       'no_repeat_ngram_size': 5,
+                                       'constraint_range': '4,21134'},
+                       'seed': 7,
+                       'max_src_length': 256,
+                       'language': 'zh',
+                       'gen_type': 'generation',
+                       'multimodal_type': 'mmspeech'},
+             'pipeline': {'type': 'ofa-asr'},
+             'n_frames_per_step': 1,
+             'dataset': {'column_map': {'wav': 'Audio:FILE', 'text': 'Text:LABEL'}},
+             'train': {'work_dir': 'work/ckpts/asr_recognition',
+                       # 'launcher': 'pytorch',
+                       'max_epochs': 1,
+                       'use_fp16': True,
+                       'dataloader': {'batch_size_per_gpu': 16, 'workers_per_gpu': 0},
+                       'lr_scheduler': {'name': 'polynomial_decay',
+                                        'warmup_proportion': 0.01,
+                                        'lr_end': 1e-07},
+                       'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
+                       'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01},
+                       'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
+                                          'cumulative_iters': 1,
+                                          'grad_clip': {'max_norm': 1.0, 'norm_type': 2},
+                                          'loss_keys': 'loss'},
+                       'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion',
+                                     'constraint_range': '4,21134',
+                                     'drop_worst_after': 0,
+                                     'drop_worst_ratio': 0.0,
+                                     'ignore_eos': False,
+                                     'ignore_prefix_size': 0,
+                                     'label_smoothing': 0.1,
+                                     'reg_alpha': 1.0,
+                                     'report_accuracy': False,
+                                     'sample_patch_num': 196,
+                                     'sentence_avg': True,
+                                     'use_rdrop': False,
+                                     'ctc_weight': 1.0},
+                       'hooks': [{'type': 'BestCkptSaverHook',
+                                  'metric_key': 'accuracy',
+                                  'interval': 100},
+                                 {'type': 'TextLoggerHook', 'interval': 1},
+                                 {'type': 'IterTimerHook'},
+                                 {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
+             'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
+                            'metrics': [{'type': 'accuracy'}]},
+             'preprocessor': []}
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_std(self):
+        WORKSPACE = './workspace/ckpts/asr_recognition'
+        os.makedirs(WORKSPACE, exist_ok=True)
+        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
+        with open(config_file, 'w') as writer:
+            json.dump(self.finetune_cfg, writer)
+
+        pretrained_model = 'damo/ofa_mmspeech_pretrain_base_zh'
+
+        args = dict(
+            model=pretrained_model,
+            work_dir=WORKSPACE,
+            train_dataset=MsDataset.load(
+                'aishell1_subset',
+                subset_name='default',
+                namespace='modelscope',
+                split='train',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
+            eval_dataset=MsDataset.load(
+                'aishell1_subset',
+                subset_name='default',
+                namespace='modelscope',
+                split='test',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
+            cfg_file=config_file)
+        trainer = build_trainer(name=Trainers.ofa, default_args=args)
+        trainer.train()
+
+        self.assertIn(
+            ModelFile.TORCH_MODEL_BIN_FILE,
+            os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
+        shutil.rmtree(WORKSPACE)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py
index 0516e569..ab2b8cc6 100644
--- a/tests/trainers/test_ofa_trainer.py
+++ b/tests/trainers/test_ofa_trainer.py
@@ -76,8 +76,7 @@ class TestOfaTrainer(unittest.TestCase):
         os.makedirs(WORKSPACE, exist_ok=True)
         config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
         with open(config_file, 'w') as writer:
-            json.dump(self.finetune_cfg, writer)
-
+            json.dump(self.finetune_cfg, writer, indent=4)
         pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
 
         args = dict(
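
Note on the new copy step in OFATrainer: recursive_overwrite mirrors the whole
pretrained model directory into the trainer work dir (replacing the old
per-language tokenizer file lists), while shutil.ignore_patterns keeps the
finetune configuration from being clobbered. A minimal standalone sketch of
that behavior follows; it is not part of the patch, assumes only the helper's
import path added above plus ModelFile.CONFIGURATION being
'configuration.json', and the temp directories and file names are made up.

import os
import tempfile
from shutil import ignore_patterns

from modelscope.trainers.multi_modal.ofa.ofa_trainer_utils import \
    recursive_overwrite

src = tempfile.mkdtemp()  # stands in for the pretrained model_dir
dst = tempfile.mkdtemp()  # stands in for the trainer work_dir

# Fake a checkpoint dir: one file that should be copied, one that should not.
open(os.path.join(src, 'vocab.txt'), 'w').close()
open(os.path.join(src, 'configuration.json'), 'w').close()

# Copy everything except the configuration, overwriting any duplicates.
recursive_overwrite(src, dst, ignore=ignore_patterns('configuration.json'))

assert os.path.exists(os.path.join(dst, 'vocab.txt'))
assert not os.path.exists(os.path.join(dst, 'configuration.json'))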