| @@ -6,6 +6,7 @@ import numpy as np | |||
| import soundfile as sf | |||
| import torch | |||
| from modelscope.fileio import File | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| @@ -34,11 +35,12 @@ class ANSPipeline(Pipeline): | |||
| super().__init__(model=model, **kwargs) | |||
| self.model.eval() | |||
| def preprocess(self, inputs: Input) -> Dict[str, Any]: | |||
| def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: | |||
| if isinstance(inputs, bytes): | |||
| data1, fs = sf.read(io.BytesIO(inputs)) | |||
| elif isinstance(inputs, str): | |||
| data1, fs = sf.read(inputs) | |||
| file_bytes = File.read(inputs) | |||
| data1, fs = sf.read(io.BytesIO(file_bytes)) | |||
| else: | |||
| raise TypeError(f'Unsupported type {type(inputs)}.') | |||
| if len(data1.shape) > 1: | |||
| @@ -50,7 +52,8 @@ class ANSPipeline(Pipeline): | |||
| inputs = np.reshape(data, [1, data.shape[0]]) | |||
| return {'ndarray': inputs, 'nsamples': data.shape[0]} | |||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| ndarray = inputs['ndarray'] | |||
| if isinstance(ndarray, torch.Tensor): | |||
| ndarray = ndarray.cpu().numpy() | |||
| @@ -2,6 +2,7 @@ import io | |||
| import wave | |||
| from typing import Any, Dict | |||
| from modelscope.fileio import File | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| @@ -39,6 +40,8 @@ class KWSFarfieldPipeline(Pipeline): | |||
| def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: | |||
| if isinstance(inputs, bytes): | |||
| return dict(input_file=inputs) | |||
| elif isinstance(inputs, str): | |||
| return dict(input_file=inputs) | |||
| elif isinstance(inputs, Dict): | |||
| return inputs | |||
| else: | |||
| @@ -47,6 +50,8 @@ class KWSFarfieldPipeline(Pipeline): | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| input_file = inputs['input_file'] | |||
| if isinstance(input_file, str): | |||
| input_file = File.read(input_file) | |||
| if isinstance(input_file, bytes): | |||
| input_file = io.BytesIO(input_file) | |||
| self.frame_count = 0 | |||
| @@ -6,9 +6,10 @@ import numpy as np | |||
| import scipy.io.wavfile as wav | |||
| import torch | |||
| from modelscope.fileio import File | |||
| from modelscope.preprocessors import Preprocessor | |||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||
| from modelscope.utils.constant import Fields | |||
| from . import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| def load_kaldi_feature_transform(filename): | |||
| @@ -201,7 +202,8 @@ class LinearAECAndFbank(Preprocessor): | |||
| if isinstance(inputs, bytes): | |||
| inputs = io.BytesIO(inputs) | |||
| elif isinstance(inputs, str): | |||
| pass | |||
| file_bytes = File.read(inputs) | |||
| inputs = io.BytesIO(file_bytes) | |||
| else: | |||
| raise TypeError(f'Unsupported input type: {type(inputs)}.') | |||
| sample_rate, data = wav.read(inputs) | |||
| @@ -6,6 +6,9 @@ from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav' | |||
| TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \ | |||
| 'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \ | |||
| '?Revision=master&FilePath=examples/3ch_nihaomiya.wav' | |||
| class KWSFarfieldTest(unittest.TestCase): | |||
| @@ -13,7 +16,7 @@ class KWSFarfieldTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_normal(self): | |||
| kws = pipeline(Tasks.keyword_spotting, model=self.model_id) | |||
| inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)} | |||
| @@ -21,6 +24,13 @@ class KWSFarfieldTest(unittest.TestCase): | |||
| self.assertEqual(len(result['kws_list']), 5) | |||
| print(result['kws_list'][-1]) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_url(self): | |||
| kws = pipeline(Tasks.keyword_spotting, model=self.model_id) | |||
| result = kws(TEST_SPEECH_URL) | |||
| self.assertEqual(len(result['kws_list']), 5) | |||
| print(result['kws_list'][-1]) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_output(self): | |||
| kws = pipeline(Tasks.keyword_spotting, model=self.model_id) | |||
| @@ -9,8 +9,17 @@ from modelscope.utils.test_utils import test_level | |||
| NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav' | |||
| FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav' | |||
| NEAREND_MIC_URL = 'https://modelscope.cn/api/v1/models/damo/' \ | |||
| 'speech_dfsmn_aec_psm_16k/repo?Revision=master' \ | |||
| '&FilePath=examples/nearend_mic.wav' | |||
| FAREND_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \ | |||
| 'speech_dfsmn_aec_psm_16k/repo?Revision=master' \ | |||
| '&FilePath=examples/farend_speech.wav' | |||
| NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav' | |||
| NOISE_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \ | |||
| 'speech_frcrn_ans_cirm_16k/repo?Revision=master' \ | |||
| '&FilePath=examples/speech_with_noise.wav' | |||
| class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @@ -18,7 +27,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| pass | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_aec(self): | |||
| model_id = 'damo/speech_dfsmn_aec_psm_16k' | |||
| input = { | |||
| @@ -30,6 +39,18 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| aec(input, output_path=output_path) | |||
| print(f'Processed audio saved to {output_path}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_aec_url(self): | |||
| model_id = 'damo/speech_dfsmn_aec_psm_16k' | |||
| input = { | |||
| 'nearend_mic': NEAREND_MIC_URL, | |||
| 'farend_speech': FAREND_SPEECH_URL | |||
| } | |||
| aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id) | |||
| output_path = os.path.abspath('output.wav') | |||
| aec(input, output_path=output_path) | |||
| print(f'Processed audio saved to {output_path}') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_aec_bytes(self): | |||
| model_id = 'damo/speech_dfsmn_aec_psm_16k' | |||
| @@ -62,7 +83,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| aec(inputs, output_path=output_path) | |||
| print(f'Processed audio saved to {output_path}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_ans(self): | |||
| model_id = 'damo/speech_frcrn_ans_cirm_16k' | |||
| ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) | |||
| @@ -71,6 +92,14 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| output_path=output_path) | |||
| print(f'Processed audio saved to {output_path}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_ans_url(self): | |||
| model_id = 'damo/speech_frcrn_ans_cirm_16k' | |||
| ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) | |||
| output_path = os.path.abspath('output.wav') | |||
| ans(NOISE_SPEECH_URL, output_path=output_path) | |||
| print(f'Processed audio saved to {output_path}') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_ans_bytes(self): | |||
| model_id = 'damo/speech_frcrn_ans_cirm_16k' | |||