Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9892528 (master)
@@ -198,6 +198,9 @@ class Trainers(object):
     nlp_base_trainer = 'nlp-base-trainer'
     nlp_veco_trainer = 'nlp-veco-trainer'

+    # audio trainers
+    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'


 class Preprocessors(object):
     """ Names for different preprocessors.
@@ -254,6 +257,7 @@ class Metrics(object):
     # accuracy
     accuracy = 'accuracy'
+    audio_noise_metric = 'audio-noise-metric'

     # metrics for image denoise task
     image_denoise_metric = 'image-denoise-metric'
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
+    from .audio_noise_metric import AudioNoiseMetric
     from .base import Metric
     from .builder import METRICS, build_metric, task_default_metrics
     from .image_color_enhance_metric import ImageColorEnhanceMetric
@@ -18,6 +19,7 @@ if TYPE_CHECKING:
 else:
     _import_structure = {
+        'audio_noise_metric': ['AudioNoiseMetric'],
         'base': ['Metric'],
         'builder': ['METRICS', 'build_metric', 'task_default_metrics'],
         'image_color_enhance_metric': ['ImageColorEnhanceMetric'],
@@ -0,0 +1,38 @@
from typing import Dict

from modelscope.metainfo import Metrics
from modelscope.metrics.base import Metric
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.utils.registry import default_group


@METRICS.register_module(
    group_key=default_group, module_name=Metrics.audio_noise_metric)
class AudioNoiseMetric(Metric):
    """
    The metric computation class for the acoustic noise suppression task.
    """

    def __init__(self):
        self.loss = []
        self.amp_loss = []
        self.phase_loss = []
        self.sisnr = []

    def add(self, outputs: Dict, inputs: Dict):
        self.loss.append(outputs['loss'].data.cpu())
        self.amp_loss.append(outputs['amp_loss'].data.cpu())
        self.phase_loss.append(outputs['phase_loss'].data.cpu())
        self.sisnr.append(outputs['sisnr'].data.cpu())

    def evaluate(self):
        avg_loss = sum(self.loss) / len(self.loss)
        avg_sisnr = sum(self.sisnr) / len(self.sisnr)
        avg_amp = sum(self.amp_loss) / len(self.amp_loss)
        avg_phase = sum(self.phase_loss) / len(self.phase_loss)
        total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr
        return {
            'total_loss': total_loss.item(),
            'avg_sisnr': avg_sisnr.item(),
            MetricKeys.AVERAGE_LOSS: avg_loss.item()
        }
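For orientation, here is a minimal sketch (not part of this change) of how the metric accumulates per-batch outputs and reduces them; the tensor values are made up:

    import torch

    from modelscope.metrics.audio_noise_metric import AudioNoiseMetric

    metric = AudioNoiseMetric()
    # two hypothetical training-step outputs, keyed like FRCRNModel.forward's
    for step in range(2):
        outputs = {
            'loss': torch.tensor(0.5 - 0.1 * step),
            'amp_loss': torch.tensor(0.3),
            'phase_loss': torch.tensor(0.2),
            'sisnr': torch.tensor(-9.0),
        }
        metric.add(outputs, inputs={})
    print(metric.evaluate())
    # ≈ {'total_loss': -8.05, 'avg_sisnr': -9.0, 'avg_loss': 0.45}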
@@ -16,6 +16,7 @@ class MetricKeys(object):
     RECALL = 'recall'
     PSNR = 'psnr'
     SSIM = 'ssim'
+    AVERAGE_LOSS = 'avg_loss'
     FScore = 'fscore'
@@ -71,32 +71,41 @@ class FRCRNModel(TorchModel):
             model_dir (str): the model path.
         """
         super().__init__(model_dir, *args, **kwargs)
+        kwargs.pop('device')
         self.model = FRCRN(*args, **kwargs)
         model_bin_file = os.path.join(model_dir,
                                       ModelFile.TORCH_MODEL_BIN_FILE)
         if os.path.exists(model_bin_file):
-            checkpoint = torch.load(model_bin_file)
-            self.model.load_state_dict(checkpoint, strict=False)
+            checkpoint = torch.load(
+                model_bin_file, map_location=torch.device('cpu'))
+            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+                self.model.load_state_dict(
+                    checkpoint['state_dict'], strict=False)
+            else:
+                self.model.load_state_dict(checkpoint, strict=False)

     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        output = self.model.forward(input)
-        return {
-            'spec_l1': output[0],
-            'wav_l1': output[1],
-            'mask_l1': output[2],
-            'spec_l2': output[3],
-            'wav_l2': output[4],
-            'mask_l2': output[5]
+        result_list = self.model.forward(input['noisy'])
+        output = {
+            'spec_l1': result_list[0],
+            'wav_l1': result_list[1],
+            'mask_l1': result_list[2],
+            'spec_l2': result_list[3],
+            'wav_l2': result_list[4],
+            'mask_l2': result_list[5]
         }
-
-    def to(self, *args, **kwargs):
-        self.model = self.model.to(*args, **kwargs)
-        return self
-
-    def eval(self):
-        self.model = self.model.train(False)
-        return self
+        if 'clean' in input:
+            mix_result = self.model.loss(
+                input['noisy'], input['clean'], result_list, mode='Mix')
+            output.update(mix_result)
+            sisnr_result = self.model.loss(
+                input['noisy'], input['clean'], result_list, mode='SiSNR')
+            output.update(sisnr_result)
+            # the logger hook reads the items under 'log_vars'
+            output['log_vars'] = {k: mix_result[k].item() for k in mix_result}
+            output['log_vars'].update(
+                {k: sisnr_result[k].item()
+                 for k in sisnr_result})
+        return output
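A quick sketch of the new forward contract (illustrative only; `model` stands for a loaded FRCRNModel instance, and the shapes are made up): 'noisy' is required, and supplying 'clean' additionally populates the loss terms and 'log_vars' for the logging hook.

    import torch

    # hypothetical batch of two 1-second clips at 16 kHz
    batch = {
        'noisy': torch.randn(2, 16000),
        'clean': torch.randn(2, 16000),
    }
    output = model(batch)        # model: a loaded FRCRNModel instance
    enhanced = output['wav_l2']  # second-stage enhanced waveform
    print(output['log_vars'])    # {'loss': ..., 'amp_loss': ..., 'phase_loss': ..., 'sisnr': ...}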
 class FRCRN(nn.Module):
@@ -111,7 +120,8 @@ class FRCRN(nn.Module):
                  win_len=400,
                  win_inc=100,
                  fft_len=512,
-                 win_type='hanning'):
+                 win_type='hanning',
+                 **kwargs):
         r"""
         Args:
             complex: Whether to use complex networks.
@@ -237,7 +247,7 @@ class FRCRN(nn.Module):
             if count != 3:
                 loss = self.loss_1layer(noisy, est_spec, est_wav, labels,
                                         est_mask, mode)
-            return loss
+            return dict(sisnr=loss)

         elif mode == 'Mix':
             count = 0
@@ -252,7 +262,7 @@ class FRCRN(nn.Module):
             amp_loss, phase_loss, SiSNR_loss = self.loss_1layer(
                 noisy, est_spec, est_wav, labels, est_mask, mode)
             loss = amp_loss + phase_loss + SiSNR_loss
-            return loss, amp_loss, phase_loss
+            return dict(loss=loss, amp_loss=amp_loss, phase_loss=phase_loss)
     def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'):
         r""" Compute the loss by mode
@@ -10,21 +10,10 @@ from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.audio.audio_utils import audio_norm
 from modelscope.utils.constant import Tasks

-
-def audio_norm(x):
-    rms = (x**2).mean()**0.5
-    scalar = 10**(-25 / 20) / rms
-    x = x * scalar
-    pow_x = x**2
-    avg_pow_x = pow_x.mean()
-    rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5
-    scalarx = 10**(-25 / 20) / rmsx
-    x = x * scalarx
-    return x
-

 @PIPELINES.register_module(
     Tasks.acoustic_noise_suppression,
     module_name=Pipelines.speech_frcrn_ans_cirm_16k)
@@ -98,7 +87,8 @@ class ANSPipeline(Pipeline):
             current_idx = 0
             while current_idx + window <= t:
                 print('current_idx: {}'.format(current_idx))
-                tmp_input = ndarray[:, current_idx:current_idx + window]
+                tmp_input = dict(noisy=ndarray[:, current_idx:current_idx
+                                               + window])
                 tmp_output = self.model(
                     tmp_input, )['wav_l2'][0].cpu().numpy()
                 end_index = current_idx + window - give_up_length
@@ -111,7 +101,8 @@ class ANSPipeline(Pipeline):
                                   give_up_length:-give_up_length]
                 current_idx += stride
         else:
-            outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy()
+            outputs = self.model(
+                dict(noisy=ndarray))['wav_l2'][0].cpu().numpy()
         outputs = (outputs[:nsamples] * 32768).astype(np.int16).tobytes()
         return {OutputKeys.OUTPUT_PCM: outputs}
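For context, a minimal end-to-end sketch of invoking this pipeline (the model id comes from the test below; the input file name is hypothetical):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ans = pipeline(
        Tasks.acoustic_noise_suppression,
        model='damo/speech_frcrn_ans_cirm_16k')
    result = ans('speech_with_noise.wav')  # hypothetical 16 kHz mono file
    with open('denoised.pcm', 'wb') as f:
        f.write(result[OutputKeys.OUTPUT_PCM])  # raw 16-bit PCM, no WAV header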
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
+    from .audio.ans_trainer import ANSTrainer
     from .base import DummyTrainer
     from .builder import build_trainer
     from .cv import (ImageInstanceSegmentationTrainer,
@@ -15,6 +16,7 @@ if TYPE_CHECKING:
 else:
     _import_structure = {
+        'audio.ans_trainer': ['ANSTrainer'],
         'base': ['DummyTrainer'],
         'builder': ['build_trainer'],
         'cv': [
@@ -0,0 +1,57 @@
import time
from typing import List, Optional, Union

from datasets import Dataset

from modelscope.metainfo import Trainers
from modelscope.preprocessors import Preprocessor
from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import TrainerStages
from modelscope.utils.data_utils import to_device
from modelscope.utils.logger import get_logger

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.speech_frcrn_ans_cirm_16k)
class ANSTrainer(EpochBasedTrainer):
    """
    A trainer for acoustic noise suppression.
    Overrides train_loop() so the dataset is consumed in a single pass.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train_loop(self, data_loader):
        """
        Based on the super method, but advances the epoch counter by step
        count over one shared iterator instead of restarting the data loader
        every epoch.
        """
        self.invoke_hook(TrainerStages.before_run)
        self._epoch = 0
        kwargs = {}
        self.model.train()
        enumerated = enumerate(data_loader)
        for _ in range(self._epoch, self._max_epochs):
            self.invoke_hook(TrainerStages.before_train_epoch)
            self._inner_iter = 0
            for i, data_batch in enumerated:
                data_batch = to_device(data_batch, self.device)
                self.data_batch = data_batch
                self._inner_iter += 1
                self.invoke_hook(TrainerStages.before_train_iter)
                self.train_step(self.model, data_batch, **kwargs)
                self.invoke_hook(TrainerStages.after_train_iter)
                del self.data_batch
                self._iter += 1
                if self._inner_iter >= self.iters_per_epoch:
                    break
            self.invoke_hook(TrainerStages.after_train_epoch)
            self._epoch += 1
        self.invoke_hook(TrainerStages.after_run)

    def prediction_step(self, model, inputs):
        pass
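The trick that makes this work is creating enumerate(data_loader) once, outside the epoch loop, so every epoch resumes the same pass over the data instead of restarting it; a standalone sketch of the pattern:

    data = range(10)      # stands in for a DataLoader
    iters_per_epoch = 3
    max_epochs = 2

    shared = enumerate(data)  # one iterator shared across all epochs
    for epoch in range(max_epochs):
        inner = 0
        for i, batch in shared:
            inner += 1
            print(f'epoch {epoch}, batch {batch}')
            if inner >= iters_per_epoch:
                break  # the next epoch resumes where this one stopped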
@@ -0,0 +1,35 @@
import numpy as np

SEGMENT_LENGTH_TRAIN = 16000


def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN):
    """
    Dataset mapping function that splits each audio clip into fixed-length
    segments. It only works in batched mode (e.g. datasets.map with
    batched=True).
    """
    noisy_arrays = []
    for x in batch['noisy']:
        length = len(x['array'])
        noisy = np.array(x['array'])
        for offset in range(segment_length, length, segment_length):
            noisy_arrays.append(noisy[offset - segment_length:offset])
    clean_arrays = []
    for x in batch['clean']:
        length = len(x['array'])
        clean = np.array(x['array'])
        for offset in range(segment_length, length, segment_length):
            clean_arrays.append(clean[offset - segment_length:offset])
    return {'noisy': noisy_arrays, 'clean': clean_arrays}
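A tiny sketch of what the mapping produces, using plain dicts in the same {'array': ...} layout the audio feature yields (values are made up):

    batch = {
        'noisy': [{'array': [0, 1, 2, 3, 4]}],
        'clean': [{'array': [0, 1, 2, 3, 4]}],
    }
    out = to_segment(batch, segment_length=2)
    # 5 samples yield two 2-sample segments; the leftover tail is dropped:
    # out['noisy'] == [array([0, 1]), array([2, 3])]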
def audio_norm(x):
    # stage 1: scale the whole signal to -25 dBFS RMS
    rms = (x**2).mean()**0.5
    scalar = 10**(-25 / 20) / rms
    x = x * scalar
    # stage 2: re-estimate the RMS over the above-average-power samples only,
    # and rescale so the active portion (not silence) sits at -25 dBFS
    pow_x = x**2
    avg_pow_x = pow_x.mean()
    rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5
    scalarx = 10**(-25 / 20) / rmsx
    x = x * scalarx
    return x
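audio_norm targets an RMS level of -25 dBFS (10 ** (-25 / 20) ≈ 0.056), then rescales using only the above-average-power samples so that active signal, not silence, sits at that level; a quick numeric check of the first stage:

    import numpy as np

    x = np.random.randn(16000).astype(np.float32)
    rms = (x**2).mean()**0.5
    y = x * (10**(-25 / 20) / rms)
    print((y**2).mean()**0.5)  # ~0.0562, i.e. -25 dBFS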
@@ -0,0 +1,56 @@
import os
import shutil
import tempfile
import unittest
from functools import partial

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.audio.audio_utils import to_segment
from modelscope.utils.test_utils import test_level

SEGMENT_LENGTH_TEST = 640


class TestANSTrainer(unittest.TestCase):

    def setUp(self):
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

        self.model_id = 'damo/speech_frcrn_ans_cirm_16k'

        hf_ds = MsDataset.load(
            'ICASSP_2021_DNS_Challenge', split='test').to_hf_dataset()
        mapped_ds = hf_ds.map(
            partial(to_segment, segment_length=SEGMENT_LENGTH_TEST),
            remove_columns=['duration'],
            batched=True,
            batch_size=2)
        self.dataset = MsDataset.from_hf_dataset(mapped_ds)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        kwargs = dict(
            model=self.model_id,
            model_revision='beta',
            train_dataset=self.dataset,
            eval_dataset=self.dataset,
            max_epochs=2,
            train_iters_per_epoch=2,
            val_iters_per_epoch=1,
            work_dir=self.tmp_dir)
        trainer = build_trainer(
            Trainers.speech_frcrn_ans_cirm_16k, default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i + 1}.pth', results_files)