Browse Source

[to #42322933] audio pipelines accept url as input

master
bin.xue 3 years ago
parent
commit
4be7737122
5 changed files with 58 additions and 9 deletions
  1. +6
    -3
      modelscope/pipelines/audio/ans_pipeline.py
  2. +5
    -0
      modelscope/pipelines/audio/kws_farfield_pipeline.py
  3. +5
    -3
      modelscope/preprocessors/audio.py
  4. +11
    -1
      tests/pipelines/test_key_word_spotting_farfield.py
  5. +31
    -2
      tests/pipelines/test_speech_signal_process.py

+ 6
- 3
modelscope/pipelines/audio/ans_pipeline.py View File

@@ -6,6 +6,7 @@ import numpy as np
import soundfile as sf import soundfile as sf
import torch import torch


from modelscope.fileio import File
from modelscope.metainfo import Pipelines from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.base import Input, Pipeline
@@ -34,11 +35,12 @@ class ANSPipeline(Pipeline):
super().__init__(model=model, **kwargs) super().__init__(model=model, **kwargs)
self.model.eval() self.model.eval()


def preprocess(self, inputs: Input) -> Dict[str, Any]:
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
if isinstance(inputs, bytes): if isinstance(inputs, bytes):
data1, fs = sf.read(io.BytesIO(inputs)) data1, fs = sf.read(io.BytesIO(inputs))
elif isinstance(inputs, str): elif isinstance(inputs, str):
data1, fs = sf.read(inputs)
file_bytes = File.read(inputs)
data1, fs = sf.read(io.BytesIO(file_bytes))
else: else:
raise TypeError(f'Unsupported type {type(inputs)}.') raise TypeError(f'Unsupported type {type(inputs)}.')
if len(data1.shape) > 1: if len(data1.shape) > 1:
@@ -50,7 +52,8 @@ class ANSPipeline(Pipeline):
inputs = np.reshape(data, [1, data.shape[0]]) inputs = np.reshape(data, [1, data.shape[0]])
return {'ndarray': inputs, 'nsamples': data.shape[0]} return {'ndarray': inputs, 'nsamples': data.shape[0]}


def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
ndarray = inputs['ndarray'] ndarray = inputs['ndarray']
if isinstance(ndarray, torch.Tensor): if isinstance(ndarray, torch.Tensor):
ndarray = ndarray.cpu().numpy() ndarray = ndarray.cpu().numpy()


+ 5
- 0
modelscope/pipelines/audio/kws_farfield_pipeline.py View File

@@ -2,6 +2,7 @@ import io
import wave import wave
from typing import Any, Dict from typing import Any, Dict


from modelscope.fileio import File
from modelscope.metainfo import Pipelines from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.base import Input, Pipeline
@@ -39,6 +40,8 @@ class KWSFarfieldPipeline(Pipeline):
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
if isinstance(inputs, bytes): if isinstance(inputs, bytes):
return dict(input_file=inputs) return dict(input_file=inputs)
elif isinstance(inputs, str):
return dict(input_file=inputs)
elif isinstance(inputs, Dict): elif isinstance(inputs, Dict):
return inputs return inputs
else: else:
@@ -47,6 +50,8 @@ class KWSFarfieldPipeline(Pipeline):
def forward(self, inputs: Dict[str, Any], def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]: **forward_params) -> Dict[str, Any]:
input_file = inputs['input_file'] input_file = inputs['input_file']
if isinstance(input_file, str):
input_file = File.read(input_file)
if isinstance(input_file, bytes): if isinstance(input_file, bytes):
input_file = io.BytesIO(input_file) input_file = io.BytesIO(input_file)
self.frame_count = 0 self.frame_count = 0


+ 5
- 3
modelscope/preprocessors/audio.py View File

@@ -6,9 +6,10 @@ import numpy as np
import scipy.io.wavfile as wav import scipy.io.wavfile as wav
import torch import torch


from modelscope.fileio import File
from modelscope.preprocessors import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields from modelscope.utils.constant import Fields
from . import Preprocessor
from .builder import PREPROCESSORS




def load_kaldi_feature_transform(filename): def load_kaldi_feature_transform(filename):
@@ -201,7 +202,8 @@ class LinearAECAndFbank(Preprocessor):
if isinstance(inputs, bytes): if isinstance(inputs, bytes):
inputs = io.BytesIO(inputs) inputs = io.BytesIO(inputs)
elif isinstance(inputs, str): elif isinstance(inputs, str):
pass
file_bytes = File.read(inputs)
inputs = io.BytesIO(file_bytes)
else: else:
raise TypeError(f'Unsupported input type: {type(inputs)}.') raise TypeError(f'Unsupported input type: {type(inputs)}.')
sample_rate, data = wav.read(inputs) sample_rate, data = wav.read(inputs)


+ 11
- 1
tests/pipelines/test_key_word_spotting_farfield.py View File

@@ -6,6 +6,9 @@ from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level from modelscope.utils.test_utils import test_level


TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav' TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \
'?Revision=master&FilePath=examples/3ch_nihaomiya.wav'




class KWSFarfieldTest(unittest.TestCase): class KWSFarfieldTest(unittest.TestCase):
@@ -13,7 +16,7 @@ class KWSFarfieldTest(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_normal(self): def test_normal(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id) kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)} inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)}
@@ -21,6 +24,13 @@ class KWSFarfieldTest(unittest.TestCase):
self.assertEqual(len(result['kws_list']), 5) self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1]) print(result['kws_list'][-1])


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_url(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
result = kws(TEST_SPEECH_URL)
self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_output(self): def test_output(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id) kws = pipeline(Tasks.keyword_spotting, model=self.model_id)


+ 31
- 2
tests/pipelines/test_speech_signal_process.py View File

@@ -9,8 +9,17 @@ from modelscope.utils.test_utils import test_level


NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav' NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav'
FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav' FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav'
NEAREND_MIC_URL = 'https://modelscope.cn/api/v1/models/damo/' \
'speech_dfsmn_aec_psm_16k/repo?Revision=master' \
'&FilePath=examples/nearend_mic.wav'
FAREND_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
'speech_dfsmn_aec_psm_16k/repo?Revision=master' \
'&FilePath=examples/farend_speech.wav'


NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav' NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav'
NOISE_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
'speech_frcrn_ans_cirm_16k/repo?Revision=master' \
'&FilePath=examples/speech_with_noise.wav'




class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck): class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
@@ -18,7 +27,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None: def setUp(self) -> None:
pass pass


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_aec(self): def test_aec(self):
model_id = 'damo/speech_dfsmn_aec_psm_16k' model_id = 'damo/speech_dfsmn_aec_psm_16k'
input = { input = {
@@ -30,6 +39,18 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
aec(input, output_path=output_path) aec(input, output_path=output_path)
print(f'Processed audio saved to {output_path}') print(f'Processed audio saved to {output_path}')


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_aec_url(self):
model_id = 'damo/speech_dfsmn_aec_psm_16k'
input = {
'nearend_mic': NEAREND_MIC_URL,
'farend_speech': FAREND_SPEECH_URL
}
aec = pipeline(Tasks.acoustic_echo_cancellation, model=model_id)
output_path = os.path.abspath('output.wav')
aec(input, output_path=output_path)
print(f'Processed audio saved to {output_path}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_aec_bytes(self): def test_aec_bytes(self):
model_id = 'damo/speech_dfsmn_aec_psm_16k' model_id = 'damo/speech_dfsmn_aec_psm_16k'
@@ -62,7 +83,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
aec(inputs, output_path=output_path) aec(inputs, output_path=output_path)
print(f'Processed audio saved to {output_path}') print(f'Processed audio saved to {output_path}')


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ans(self): def test_ans(self):
model_id = 'damo/speech_frcrn_ans_cirm_16k' model_id = 'damo/speech_frcrn_ans_cirm_16k'
ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
@@ -71,6 +92,14 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
output_path=output_path) output_path=output_path)
print(f'Processed audio saved to {output_path}') print(f'Processed audio saved to {output_path}')


@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_ans_url(self):
model_id = 'damo/speech_frcrn_ans_cirm_16k'
ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
output_path = os.path.abspath('output.wav')
ans(NOISE_SPEECH_URL, output_path=output_path)
print(f'Processed audio saved to {output_path}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ans_bytes(self): def test_ans_bytes(self):
model_id = 'damo/speech_frcrn_ans_cirm_16k' model_id = 'damo/speech_frcrn_ans_cirm_16k'


Loading…
Cancel
Save