diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 58fd4f46..eab870ae 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -198,6 +198,9 @@ class Trainers(object):
     nlp_base_trainer = 'nlp-base-trainer'
     nlp_veco_trainer = 'nlp-veco-trainer'
 
+    # audio trainers
+    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
+
 
 class Preprocessors(object):
     """ Names for different preprocessor.
@@ -254,6 +257,7 @@ class Metrics(object):
 
     # accuracy
     accuracy = 'accuracy'
+    audio_noise_metric = 'audio-noise-metric'
 
     # metrics for image denoise task
     image_denoise_metric = 'image-denoise-metric'
diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py
index d307f7c9..c74b475e 100644
--- a/modelscope/metrics/__init__.py
+++ b/modelscope/metrics/__init__.py
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
+    from .audio_noise_metric import AudioNoiseMetric
     from .base import Metric
     from .builder import METRICS, build_metric, task_default_metrics
     from .image_color_enhance_metric import ImageColorEnhanceMetric
@@ -18,6 +19,7 @@
 
 else:
     _import_structure = {
+        'audio_noise_metric': ['AudioNoiseMetric'],
        'base': ['Metric'],
        'builder': ['METRICS', 'build_metric', 'task_default_metrics'],
        'image_color_enhance_metric': ['ImageColorEnhanceMetric'],
diff --git a/modelscope/metrics/audio_noise_metric.py b/modelscope/metrics/audio_noise_metric.py
new file mode 100644
index 00000000..16c5261f
--- /dev/null
+++ b/modelscope/metrics/audio_noise_metric.py
@@ -0,0 +1,38 @@
+from typing import Dict
+
+from modelscope.metainfo import Metrics
+from modelscope.metrics.base import Metric
+from modelscope.metrics.builder import METRICS, MetricKeys
+from modelscope.utils.registry import default_group
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.audio_noise_metric)
+class AudioNoiseMetric(Metric):
+    """
+    The metric computation class for the acoustic noise suppression task.
+    """
+
+    def __init__(self):
+        self.loss = []
+        self.amp_loss = []
+        self.phase_loss = []
+        self.sisnr = []
+
+    def add(self, outputs: Dict, inputs: Dict):
+        self.loss.append(outputs['loss'].data.cpu())
+        self.amp_loss.append(outputs['amp_loss'].data.cpu())
+        self.phase_loss.append(outputs['phase_loss'].data.cpu())
+        self.sisnr.append(outputs['sisnr'].data.cpu())
+
+    def evaluate(self):
+        avg_loss = sum(self.loss) / len(self.loss)
+        avg_sisnr = sum(self.sisnr) / len(self.sisnr)
+        avg_amp = sum(self.amp_loss) / len(self.amp_loss)
+        avg_phase = sum(self.phase_loss) / len(self.phase_loss)
+        total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr
+        return {
+            'total_loss': total_loss.item(),
+            'avg_sisnr': avg_sisnr.item(),
+            MetricKeys.AVERAGE_LOSS: avg_loss.item()
+        }
diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py
index 9ba80a6c..869a1ab2 100644
--- a/modelscope/metrics/builder.py
+++ b/modelscope/metrics/builder.py
@@ -16,6 +16,7 @@ class MetricKeys(object):
     RECALL = 'recall'
     PSNR = 'psnr'
     SSIM = 'ssim'
+    AVERAGE_LOSS = 'avg_loss'
     FScore = 'fscore'
 
 
diff --git a/modelscope/models/audio/ans/frcrn.py b/modelscope/models/audio/ans/frcrn.py
index ba78ab74..59411fbe 100644
--- a/modelscope/models/audio/ans/frcrn.py
+++ b/modelscope/models/audio/ans/frcrn.py
@@ -71,32 +71,41 @@ class FRCRNModel(TorchModel):
             model_dir (str): the model path.
""" super().__init__(model_dir, *args, **kwargs) - kwargs.pop('device') self.model = FRCRN(*args, **kwargs) model_bin_file = os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE) if os.path.exists(model_bin_file): - checkpoint = torch.load(model_bin_file) - self.model.load_state_dict(checkpoint, strict=False) + checkpoint = torch.load( + model_bin_file, map_location=torch.device('cpu')) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + self.model.load_state_dict( + checkpoint['state_dict'], strict=False) + else: + self.model.load_state_dict(checkpoint, strict=False) def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: - output = self.model.forward(input) - return { - 'spec_l1': output[0], - 'wav_l1': output[1], - 'mask_l1': output[2], - 'spec_l2': output[3], - 'wav_l2': output[4], - 'mask_l2': output[5] + result_list = self.model.forward(input['noisy']) + output = { + 'spec_l1': result_list[0], + 'wav_l1': result_list[1], + 'mask_l1': result_list[2], + 'spec_l2': result_list[3], + 'wav_l2': result_list[4], + 'mask_l2': result_list[5] } - - def to(self, *args, **kwargs): - self.model = self.model.to(*args, **kwargs) - return self - - def eval(self): - self.model = self.model.train(False) - return self + if 'clean' in input: + mix_result = self.model.loss( + input['noisy'], input['clean'], result_list, mode='Mix') + output.update(mix_result) + sisnr_result = self.model.loss( + input['noisy'], input['clean'], result_list, mode='SiSNR') + output.update(sisnr_result) + # logger hooker will use items under 'log_vars' + output['log_vars'] = {k: mix_result[k].item() for k in mix_result} + output['log_vars'].update( + {k: sisnr_result[k].item() + for k in sisnr_result}) + return output class FRCRN(nn.Module): @@ -111,7 +120,8 @@ class FRCRN(nn.Module): win_len=400, win_inc=100, fft_len=512, - win_type='hanning'): + win_type='hanning', + **kwargs): r""" Args: complex: Whether to use complex networks. 
@@ -237,7 +247,7 @@ class FRCRN(nn.Module):
                 if count != 3:
                     loss = self.loss_1layer(noisy, est_spec, est_wav, labels,
                                             est_mask, mode)
-            return loss
+            return dict(sisnr=loss)
 
         elif mode == 'Mix':
             count = 0
@@ -252,7 +262,7 @@ class FRCRN(nn.Module):
             amp_loss, phase_loss, SiSNR_loss = self.loss_1layer(
                 noisy, est_spec, est_wav, labels, est_mask, mode)
             loss = amp_loss + phase_loss + SiSNR_loss
-            return loss, amp_loss, phase_loss
+            return dict(loss=loss, amp_loss=amp_loss, phase_loss=phase_loss)
 
     def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'):
         r""" Compute the loss by mode
diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py
index 410a7cb5..5ed4d769 100644
--- a/modelscope/pipelines/audio/ans_pipeline.py
+++ b/modelscope/pipelines/audio/ans_pipeline.py
@@ -10,21 +10,10 @@ from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.audio.audio_utils import audio_norm
 from modelscope.utils.constant import Tasks
 
 
-def audio_norm(x):
-    rms = (x**2).mean()**0.5
-    scalar = 10**(-25 / 20) / rms
-    x = x * scalar
-    pow_x = x**2
-    avg_pow_x = pow_x.mean()
-    rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5
-    scalarx = 10**(-25 / 20) / rmsx
-    x = x * scalarx
-    return x
-
-
 @PIPELINES.register_module(
     Tasks.acoustic_noise_suppression,
     module_name=Pipelines.speech_frcrn_ans_cirm_16k)
@@ -98,7 +87,8 @@ class ANSPipeline(Pipeline):
             current_idx = 0
             while current_idx + window <= t:
                 print('current_idx: {}'.format(current_idx))
-                tmp_input = ndarray[:, current_idx:current_idx + window]
+                tmp_input = dict(noisy=ndarray[:, current_idx:current_idx
+                                               + window])
                 tmp_output = self.model(
                     tmp_input, )['wav_l2'][0].cpu().numpy()
                 end_index = current_idx + window - give_up_length
@@ -111,7 +101,8 @@ class ANSPipeline(Pipeline):
                             give_up_length:-give_up_length]
                 current_idx += stride
         else:
-            outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy()
+            outputs = self.model(
+                dict(noisy=ndarray))['wav_l2'][0].cpu().numpy()
 
         outputs = (outputs[:nsamples] * 32768).astype(np.int16).tobytes()
         return {OutputKeys.OUTPUT_PCM: outputs}
diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py
index 17ed7f3c..32ff674f 100644
--- a/modelscope/trainers/__init__.py
+++ b/modelscope/trainers/__init__.py
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
+    from .audio.ans_trainer import ANSTrainer
     from .base import DummyTrainer
     from .builder import build_trainer
     from .cv import (ImageInstanceSegmentationTrainer,
@@ -15,6 +16,7 @@
 
 else:
     _import_structure = {
+        'audio.ans_trainer': ['ANSTrainer'],
        'base': ['DummyTrainer'],
        'builder': ['build_trainer'],
        'cv': [
diff --git a/modelscope/trainers/audio/__init__.py b/modelscope/trainers/audio/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/trainers/audio/ans_trainer.py b/modelscope/trainers/audio/ans_trainer.py
new file mode 100644
index 00000000..f782b836
--- /dev/null
+++ b/modelscope/trainers/audio/ans_trainer.py
@@ -0,0 +1,57 @@
+import time
+from typing import List, Optional, Union
+
+from datasets import Dataset
+
+from modelscope.metainfo import Trainers
+from modelscope.preprocessors import Preprocessor
+from modelscope.trainers import EpochBasedTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.utils.constant import TrainerStages
+from modelscope.utils.data_utils import to_device
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@TRAINERS.register_module(module_name=Trainers.speech_frcrn_ans_cirm_16k)
+class ANSTrainer(EpochBasedTrainer):
+    """
+    Trainer for the acoustic noise suppression task.
+    Overrides train_loop() so the dataset is iterated through only once.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def train_loop(self, data_loader):
+        """
+        Based on the super method, but advances the epoch every `iters_per_epoch` steps of a single pass over the data loader.
+        """
+        self.invoke_hook(TrainerStages.before_run)
+        self._epoch = 0
+        kwargs = {}
+        self.model.train()
+        enumerated = enumerate(data_loader)
+        for _ in range(self._epoch, self._max_epochs):
+            self.invoke_hook(TrainerStages.before_train_epoch)
+            self._inner_iter = 0
+            for i, data_batch in enumerated:
+                data_batch = to_device(data_batch, self.device)
+                self.data_batch = data_batch
+                self._inner_iter += 1
+                self.invoke_hook(TrainerStages.before_train_iter)
+                self.train_step(self.model, data_batch, **kwargs)
+                self.invoke_hook(TrainerStages.after_train_iter)
+                del self.data_batch
+                self._iter += 1
+                if self._inner_iter >= self.iters_per_epoch:
+                    break
+
+            self.invoke_hook(TrainerStages.after_train_epoch)
+            self._epoch += 1
+
+        self.invoke_hook(TrainerStages.after_run)
+
+    def prediction_step(self, model, inputs):
+        pass
diff --git a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py
new file mode 100644
index 00000000..14374c65
--- /dev/null
+++ b/modelscope/utils/audio/audio_utils.py
@@ -0,0 +1,35 @@
+import numpy as np
+
+SEGMENT_LENGTH_TRAIN = 16000
+
+
+def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN):
+    """
+    Dataset mapping function that splits each audio clip into fixed-length segments.
+    It only works in batched mode.
+ """ + noisy_arrays = [] + for x in batch['noisy']: + length = len(x['array']) + noisy = np.array(x['array']) + for offset in range(segment_length, length, segment_length): + noisy_arrays.append(noisy[offset - segment_length:offset]) + clean_arrays = [] + for x in batch['clean']: + length = len(x['array']) + clean = np.array(x['array']) + for offset in range(segment_length, length, segment_length): + clean_arrays.append(clean[offset - segment_length:offset]) + return {'noisy': noisy_arrays, 'clean': clean_arrays} + + +def audio_norm(x): + rms = (x**2).mean()**0.5 + scalar = 10**(-25 / 20) / rms + x = x * scalar + pow_x = x**2 + avg_pow_x = pow_x.mean() + rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5 + scalarx = 10**(-25 / 20) / rmsx + x = x * scalarx + return x diff --git a/tests/trainers/audio/__init__.py b/tests/trainers/audio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/audio/test_ans_trainer.py b/tests/trainers/audio/test_ans_trainer.py new file mode 100644 index 00000000..176c811f --- /dev/null +++ b/tests/trainers/audio/test_ans_trainer.py @@ -0,0 +1,56 @@ +import os +import shutil +import tempfile +import unittest +from functools import partial + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.audio.audio_utils import to_segment +from modelscope.utils.test_utils import test_level + +SEGMENT_LENGTH_TEST = 640 + + +class TestANSTrainer(unittest.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/speech_frcrn_ans_cirm_16k' + + hf_ds = MsDataset.load( + 'ICASSP_2021_DNS_Challenge', split='test').to_hf_dataset() + mapped_ds = hf_ds.map( + partial(to_segment, segment_length=SEGMENT_LENGTH_TEST), + remove_columns=['duration'], + batched=True, + batch_size=2) + self.dataset = MsDataset.from_hf_dataset(mapped_ds) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + model_revision='beta', + train_dataset=self.dataset, + eval_dataset=self.dataset, + max_epochs=2, + train_iters_per_epoch=2, + val_iters_per_epoch=1, + work_dir=self.tmp_dir) + + trainer = build_trainer( + Trainers.speech_frcrn_ans_cirm_16k, default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i + 1}.pth', results_files)