From 84c384cc57152005e8e45422cdecc4817cb042e8 Mon Sep 17 00:00:00 2001 From: "shichen.fsc" Date: Fri, 9 Sep 2022 10:06:20 +0800 Subject: [PATCH] [to #42322933] add httpurl support for KWS Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10078262 --- .../pipelines/audio/kws_kwsbp_pipeline.py | 9 +++++ modelscope/utils/audio/audio_utils.py | 35 +++++++++++-------- tests/pipelines/test_key_word_spotting.py | 23 ++++++++++++ 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 1f31766a..866b8d0b 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -8,6 +8,8 @@ from modelscope.models import Model from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import WavToLists +from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, + load_bytes_from_url) from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -40,6 +42,13 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): if self.preprocessor is None: self.preprocessor = WavToLists() + if isinstance(audio_in, str): + # load pcm data from url if audio_in is url str + audio_in = load_bytes_from_url(audio_in) + elif isinstance(audio_in, bytes): + # load pcm data from wav data if audio_in is wave format + audio_in = extract_pcm_from_wav(audio_in) + output = self.preprocessor.forward(self.model.forward(), audio_in) output = self.forward(output) rst = self.postprocess(output) diff --git a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py index c93e0102..4c2c45cc 100644 --- a/modelscope/utils/audio/audio_utils.py +++ b/modelscope/utils/audio/audio_utils.py @@ -42,23 +42,28 @@ def extract_pcm_from_wav(wav: bytes) -> bytes: if len(data) > 44: frame_len = 44 file_len = len(data) - header_fields = {} - header_fields['ChunkID'] = str(data[0:4], 'UTF-8') - header_fields['Format'] = str(data[8:12], 'UTF-8') - header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') - if header_fields['ChunkID'] == 'RIFF' and header_fields[ - 'Format'] == 'WAVE' and header_fields['Subchunk1ID'] == 'fmt ': - header_fields['SubChunk1Size'] = struct.unpack('= 0, 'skip test in current test level') + def test_run_with_url(self): + kws_result = self.run_pipeline( + model_id=self.model_id, audio_in=URL_FILE) + self.check_result('test_run_with_url', kws_result) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_pos_testsets(self): wav_file_path = download_and_untar(