diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py
index 5ed4d769..62399684 100644
--- a/modelscope/pipelines/audio/ans_pipeline.py
+++ b/modelscope/pipelines/audio/ans_pipeline.py
@@ -6,6 +6,7 @@ import numpy as np
 import soundfile as sf
 import torch
 
+from modelscope.fileio import File
 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
@@ -34,11 +35,12 @@ class ANSPipeline(Pipeline):
         super().__init__(model=model, **kwargs)
         self.model.eval()
 
-    def preprocess(self, inputs: Input) -> Dict[str, Any]:
+    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
         if isinstance(inputs, bytes):
             data1, fs = sf.read(io.BytesIO(inputs))
         elif isinstance(inputs, str):
-            data1, fs = sf.read(inputs)
+            file_bytes = File.read(inputs)
+            data1, fs = sf.read(io.BytesIO(file_bytes))
         else:
             raise TypeError(f'Unsupported type {type(inputs)}.')
         if len(data1.shape) > 1:
@@ -50,7 +52,8 @@ class ANSPipeline(Pipeline):
         inputs = np.reshape(data, [1, data.shape[0]])
         return {'ndarray': inputs, 'nsamples': data.shape[0]}
 
-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
         ndarray = inputs['ndarray']
         if isinstance(ndarray, torch.Tensor):
             ndarray = ndarray.cpu().numpy()
diff --git a/modelscope/pipelines/audio/kws_farfield_pipeline.py b/modelscope/pipelines/audio/kws_farfield_pipeline.py
index a114e7fb..62848a27 100644
--- a/modelscope/pipelines/audio/kws_farfield_pipeline.py
+++ b/modelscope/pipelines/audio/kws_farfield_pipeline.py
@@ -2,6 +2,7 @@ import io
 import wave
 from typing import Any, Dict
 
+from modelscope.fileio import File
 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
@@ -39,6 +40,8 @@ class KWSFarfieldPipeline(Pipeline):
     def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
         if isinstance(inputs, bytes):
             return dict(input_file=inputs)
+        elif isinstance(inputs, str):
+            return dict(input_file=inputs)
         elif isinstance(inputs, Dict):
             return inputs
         else:
@@ -47,6 +50,8 @@
     def forward(self, inputs: Dict[str, Any],
                 **forward_params) -> Dict[str, Any]:
         input_file = inputs['input_file']
+        if isinstance(input_file, str):
+            input_file = File.read(input_file)
         if isinstance(input_file, bytes):
             input_file = io.BytesIO(input_file)
         self.frame_count = 0
diff --git a/modelscope/preprocessors/audio.py b/modelscope/preprocessors/audio.py
index 10057034..dd2f1fc1 100644
--- a/modelscope/preprocessors/audio.py
+++ b/modelscope/preprocessors/audio.py
@@ -6,9 +6,10 @@ import numpy as np
 import scipy.io.wavfile as wav
 import torch
 
+from modelscope.fileio import File
+from modelscope.preprocessors import Preprocessor
+from modelscope.preprocessors.builder import PREPROCESSORS
 from modelscope.utils.constant import Fields
-from . import Preprocessor
-from .builder import PREPROCESSORS
 
 
 def load_kaldi_feature_transform(filename):
@@ -201,7 +202,8 @@ class LinearAECAndFbank(Preprocessor):
         if isinstance(inputs, bytes):
             inputs = io.BytesIO(inputs)
         elif isinstance(inputs, str):
-            pass
+            file_bytes = File.read(inputs)
+            inputs = io.BytesIO(file_bytes)
         else:
             raise TypeError(f'Unsupported input type: {type(inputs)}.')
         sample_rate, data = wav.read(inputs)
diff --git a/tests/pipelines/test_key_word_spotting_farfield.py b/tests/pipelines/test_key_word_spotting_farfield.py
index 4a732950..1b23a6a7 100644
--- a/tests/pipelines/test_key_word_spotting_farfield.py
+++ b/tests/pipelines/test_key_word_spotting_farfield.py
@@ -6,6 +6,9 @@ from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level
 
 TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
+TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
+                  'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \
+                  '?Revision=master&FilePath=examples/3ch_nihaomiya.wav'
 
 
 class KWSFarfieldTest(unittest.TestCase):
@@ -13,7 +16,7 @@ class KWSFarfieldTest(unittest.TestCase):
     def setUp(self) -> None:
         self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_normal(self):
         kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
         inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)}
@@ -21,6 +24,13 @@ class KWSFarfieldTest(unittest.TestCase):
         self.assertEqual(len(result['kws_list']), 5)
         print(result['kws_list'][-1])
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_url(self):
+        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
+        result = kws(TEST_SPEECH_URL)
+        self.assertEqual(len(result['kws_list']), 5)
+        print(result['kws_list'][-1])
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_output(self):
         kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py
index 8ca6bf1d..e1987c28 100644
--- a/tests/pipelines/test_speech_signal_process.py
+++ b/tests/pipelines/test_speech_signal_process.py
@@ -9,8 +9,17 @@ from modelscope.utils.test_utils import test_level
 
 NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav'
 FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav'
+NEAREND_MIC_URL = 'https://modelscope.cn/api/v1/models/damo/' \
+                  'speech_dfsmn_aec_psm_16k/repo?Revision=master' \
+                  '&FilePath=examples/nearend_mic.wav'
+FAREND_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
+                    'speech_dfsmn_aec_psm_16k/repo?Revision=master' \
+                    '&FilePath=examples/farend_speech.wav'
 
 NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav'
+NOISE_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
+                   'speech_frcrn_ans_cirm_16k/repo?Revision=master' \
+                   '&FilePath=examples/speech_with_noise.wav'
 
 
 class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
@@ -18,7 +27,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
     def setUp(self) -> None:
        pass
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_aec(self):
         model_id = 'damo/speech_dfsmn_aec_psm_16k'
         input = {
@@ -30,6 +39,18 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
         aec(input, output_path=output_path)
         print(f'Processed audio saved to {output_path}')
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_aec_url(self):
+        model_id = 'damo/speech_dfsmn_aec_psm_16k'
+        input = {
+            'nearend_mic': NEAREND_MIC_URL,
+            'farend_speech': FAREND_SPEECH_URL
+        }
+        aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id)
+        output_path = os.path.abspath('output.wav')
+        aec(input, output_path=output_path)
+        print(f'Processed audio saved to {output_path}')
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_aec_bytes(self):
         model_id = 'damo/speech_dfsmn_aec_psm_16k'
@@ -62,7 +83,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
         aec(inputs, output_path=output_path)
         print(f'Processed audio saved to {output_path}')
 
-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_ans(self):
         model_id = 'damo/speech_frcrn_ans_cirm_16k'
         ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
@@ -71,6 +92,14 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
             output_path=output_path)
         print(f'Processed audio saved to {output_path}')
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_ans_url(self):
+        model_id = 'damo/speech_frcrn_ans_cirm_16k'
+        ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
+        output_path = os.path.abspath('output.wav')
+        ans(NOISE_SPEECH_URL, output_path=output_path)
+        print(f'Processed audio saved to {output_path}')
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_ans_bytes(self):
         model_id = 'damo/speech_frcrn_ans_cirm_16k'
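
Usage note (a minimal sketch, not part of the diff): with str inputs now routed
through File.read(), these pipelines accept an http(s) URL anywhere they
previously took a local path. The model ID and URL below are copied from the
test constants above; everything else assumes a standard ModelScope install.

    import os

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # URL copied from NOISE_SPEECH_URL in test_speech_signal_process.py.
    NOISE_SPEECH_URL = ('https://modelscope.cn/api/v1/models/damo/'
                        'speech_frcrn_ans_cirm_16k/repo?Revision=master'
                        '&FilePath=examples/speech_with_noise.wav')

    # Mirrors test_ans_url: the denoised audio is written to output_path.
    ans = pipeline(Tasks.acoustic_noise_suppression,
                   model='damo/speech_frcrn_ans_cirm_16k')
    ans(NOISE_SPEECH_URL, output_path=os.path.abspath('output.wav'))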
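The touched call sites all apply the same str-handling pattern: bytes pass
through unchanged, while a str (local path or URL) is fetched with File.read()
and wrapped in io.BytesIO before decoding. A self-contained sketch of that
pattern follows; the helper name _to_filelike is hypothetical, since the
pipelines inline this logic rather than share a helper.

    import io

    from modelscope.fileio import File

    def _to_filelike(inputs):
        # bytes: wrap directly; str: File.read() resolves both local paths
        # and http(s) URLs to bytes. (Hypothetical helper for illustration.)
        if isinstance(inputs, bytes):
            return io.BytesIO(inputs)
        elif isinstance(inputs, str):
            return io.BytesIO(File.read(inputs))
        raise TypeError(f'Unsupported input type: {type(inputs)}.')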