Browse Source

[to #42322933] Audio pipelines accept URLs as input

master
bin.xue 3 years ago
parent
commit
4be7737122
5 changed files with 58 additions and 9 deletions
  1. +6
    -3
      modelscope/pipelines/audio/ans_pipeline.py
  2. +5
    -0
      modelscope/pipelines/audio/kws_farfield_pipeline.py
  3. +5
    -3
      modelscope/preprocessors/audio.py
  4. +11
    -1
      tests/pipelines/test_key_word_spotting_farfield.py
  5. +31
    -2
      tests/pipelines/test_speech_signal_process.py

+ 6
- 3
modelscope/pipelines/audio/ans_pipeline.py View File

@@ -6,6 +6,7 @@ import numpy as np
import soundfile as sf
import torch

from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
@@ -34,11 +35,12 @@ class ANSPipeline(Pipeline):
super().__init__(model=model, **kwargs)
self.model.eval()

def preprocess(self, inputs: Input) -> Dict[str, Any]:
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
if isinstance(inputs, bytes):
data1, fs = sf.read(io.BytesIO(inputs))
elif isinstance(inputs, str):
data1, fs = sf.read(inputs)
file_bytes = File.read(inputs)
data1, fs = sf.read(io.BytesIO(file_bytes))
else:
raise TypeError(f'Unsupported type {type(inputs)}.')
if len(data1.shape) > 1:
@@ -50,7 +52,8 @@ class ANSPipeline(Pipeline):
inputs = np.reshape(data, [1, data.shape[0]])
return {'ndarray': inputs, 'nsamples': data.shape[0]}

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
ndarray = inputs['ndarray']
if isinstance(ndarray, torch.Tensor):
ndarray = ndarray.cpu().numpy()


+ 5
- 0
modelscope/pipelines/audio/kws_farfield_pipeline.py View File

@@ -2,6 +2,7 @@ import io
import wave
from typing import Any, Dict

from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
@@ -39,6 +40,8 @@ class KWSFarfieldPipeline(Pipeline):
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
if isinstance(inputs, bytes):
return dict(input_file=inputs)
elif isinstance(inputs, str):
return dict(input_file=inputs)
elif isinstance(inputs, Dict):
return inputs
else:
@@ -47,6 +50,8 @@ class KWSFarfieldPipeline(Pipeline):
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
input_file = inputs['input_file']
if isinstance(input_file, str):
input_file = File.read(input_file)
if isinstance(input_file, bytes):
input_file = io.BytesIO(input_file)
self.frame_count = 0


+ 5
- 3
modelscope/preprocessors/audio.py View File

@@ -6,9 +6,10 @@ import numpy as np
import scipy.io.wavfile as wav
import torch

from modelscope.fileio import File
from modelscope.preprocessors import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields
from . import Preprocessor
from .builder import PREPROCESSORS


def load_kaldi_feature_transform(filename):
@@ -201,7 +202,8 @@ class LinearAECAndFbank(Preprocessor):
if isinstance(inputs, bytes):
inputs = io.BytesIO(inputs)
elif isinstance(inputs, str):
pass
file_bytes = File.read(inputs)
inputs = io.BytesIO(file_bytes)
else:
raise TypeError(f'Unsupported input type: {type(inputs)}.')
sample_rate, data = wav.read(inputs)


+ 11
- 1
tests/pipelines/test_key_word_spotting_farfield.py View File

@@ -6,6 +6,9 @@ from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
TEST_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
'speech_dfsmn_kws_char_farfield_16k_nihaomiya/repo' \
'?Revision=master&FilePath=examples/3ch_nihaomiya.wav'


class KWSFarfieldTest(unittest.TestCase):
@@ -13,7 +16,7 @@ class KWSFarfieldTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_normal(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
inputs = {'input_file': os.path.join(os.getcwd(), TEST_SPEECH_FILE)}
@@ -21,6 +24,13 @@ class KWSFarfieldTest(unittest.TestCase):
self.assertEqual(len(result['kws_list']), 5)
print(result['kws_list'][-1])

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_url(self):
    """Run far-field keyword spotting on audio referenced by URL.

    The pipeline is called with the ``TEST_SPEECH_URL`` string directly
    (not a local path or a dict), and the returned ``kws_list`` is
    expected to contain exactly five entries.
    """
    detector = pipeline(Tasks.keyword_spotting, model=self.model_id)
    detections = detector(TEST_SPEECH_URL)['kws_list']
    self.assertEqual(len(detections), 5)
    # Show the last detection for manual inspection of the test log.
    print(detections[-1])

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_output(self):
kws = pipeline(Tasks.keyword_spotting, model=self.model_id)


+ 31
- 2
tests/pipelines/test_speech_signal_process.py View File

@@ -9,8 +9,17 @@ from modelscope.utils.test_utils import test_level

NEAREND_MIC_FILE = 'data/test/audios/nearend_mic.wav'
FAREND_SPEECH_FILE = 'data/test/audios/farend_speech.wav'
NEAREND_MIC_URL = 'https://modelscope.cn/api/v1/models/damo/' \
'speech_dfsmn_aec_psm_16k/repo?Revision=master' \
'&FilePath=examples/nearend_mic.wav'
FAREND_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
'speech_dfsmn_aec_psm_16k/repo?Revision=master' \
'&FilePath=examples/farend_speech.wav'

NOISE_SPEECH_FILE = 'data/test/audios/speech_with_noise.wav'
NOISE_SPEECH_URL = 'https://modelscope.cn/api/v1/models/damo/' \
'speech_frcrn_ans_cirm_16k/repo?Revision=master' \
'&FilePath=examples/speech_with_noise.wav'


class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
@@ -18,7 +27,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
pass

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_aec(self):
model_id = 'damo/speech_dfsmn_aec_psm_16k'
input = {
@@ -30,6 +39,18 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
aec(input, output_path=output_path)
print(f'Processed audio saved to {output_path}')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_aec_url(self):
    """Run acoustic echo cancellation with both inputs given as URLs.

    The near-end microphone and far-end speech recordings are passed as
    URL strings instead of local paths. ``output_path`` (``output.wav``
    resolved against the current working directory) is handed to the
    pipeline call; no assertion is made on the result.
    """
    aec = pipeline(
        Tasks.acoustic_echo_cancellation,
        model='damo/speech_dfsmn_aec_psm_16k')
    output_path = os.path.abspath('output.wav')
    audio_urls = dict(
        nearend_mic=NEAREND_MIC_URL, farend_speech=FAREND_SPEECH_URL)
    aec(audio_urls, output_path=output_path)
    print(f'Processed audio saved to {output_path}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_aec_bytes(self):
model_id = 'damo/speech_dfsmn_aec_psm_16k'
@@ -62,7 +83,7 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
aec(inputs, output_path=output_path)
print(f'Processed audio saved to {output_path}')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ans(self):
model_id = 'damo/speech_frcrn_ans_cirm_16k'
ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id)
@@ -71,6 +92,14 @@ class SpeechSignalProcessTest(unittest.TestCase, DemoCompatibilityCheck):
output_path=output_path)
print(f'Processed audio saved to {output_path}')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_ans_url(self):
    """Run acoustic noise suppression on a recording referenced by URL.

    ``NOISE_SPEECH_URL`` is passed straight to the pipeline together with
    ``output_path`` (``output.wav`` resolved against the current working
    directory); no assertion is made on the result.
    """
    ans = pipeline(
        Tasks.acoustic_noise_suppression,
        model='damo/speech_frcrn_ans_cirm_16k')
    output_path = os.path.abspath('output.wav')
    ans(NOISE_SPEECH_URL, output_path=output_path)
    print(f'Processed audio saved to {output_path}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ans_bytes(self):
model_id = 'damo/speech_frcrn_ans_cirm_16k'


Loading…
Cancel
Save