Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10078262master
| @@ -8,6 +8,8 @@ from modelscope.models import Model | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import WavToLists | |||
| from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, | |||
| load_bytes_from_url) | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -40,6 +42,13 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
| if self.preprocessor is None: | |||
| self.preprocessor = WavToLists() | |||
| if isinstance(audio_in, str): | |||
| # load pcm data from url if audio_in is url str | |||
| audio_in = load_bytes_from_url(audio_in) | |||
| elif isinstance(audio_in, bytes): | |||
| # load pcm data from wav data if audio_in is wave format | |||
| audio_in = extract_pcm_from_wav(audio_in) | |||
| output = self.preprocessor.forward(self.model.forward(), audio_in) | |||
| output = self.forward(output) | |||
| rst = self.postprocess(output) | |||
| @@ -42,23 +42,28 @@ def extract_pcm_from_wav(wav: bytes) -> bytes: | |||
| if len(data) > 44: | |||
| frame_len = 44 | |||
| file_len = len(data) | |||
| header_fields = {} | |||
| header_fields['ChunkID'] = str(data[0:4], 'UTF-8') | |||
| header_fields['Format'] = str(data[8:12], 'UTF-8') | |||
| header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') | |||
| if header_fields['ChunkID'] == 'RIFF' and header_fields[ | |||
| 'Format'] == 'WAVE' and header_fields['Subchunk1ID'] == 'fmt ': | |||
| header_fields['SubChunk1Size'] = struct.unpack('<I', | |||
| data[16:20])[0] | |||
| try: | |||
| header_fields = {} | |||
| header_fields['ChunkID'] = str(data[0:4], 'UTF-8') | |||
| header_fields['Format'] = str(data[8:12], 'UTF-8') | |||
| header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') | |||
| if header_fields['ChunkID'] == 'RIFF' and header_fields[ | |||
| 'Format'] == 'WAVE' and header_fields[ | |||
| 'Subchunk1ID'] == 'fmt ': | |||
| header_fields['SubChunk1Size'] = struct.unpack( | |||
| '<I', data[16:20])[0] | |||
| if header_fields['SubChunk1Size'] == 16: | |||
| frame_len = 44 | |||
| elif header_fields['SubChunk1Size'] == 18: | |||
| frame_len = 46 | |||
| else: | |||
| return data | |||
| if header_fields['SubChunk1Size'] == 16: | |||
| frame_len = 44 | |||
| elif header_fields['SubChunk1Size'] == 18: | |||
| frame_len = 46 | |||
| else: | |||
| return data | |||
| data = wav[frame_len:file_len] | |||
| data = wav[frame_len:file_len] | |||
| except Exception: | |||
| # no treatment | |||
| pass | |||
| return data | |||
| @@ -18,6 +18,7 @@ logger = get_logger() | |||
| POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' | |||
| BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' | |||
| URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/20200707_xiaoyun.wav' | |||
| POS_TESTSETS_FILE = 'pos_testsets.tar.gz' | |||
| POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' | |||
| @@ -76,6 +77,22 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| }] | |||
| } | |||
| }, | |||
| 'test_run_with_url': { | |||
| 'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], | |||
| 'checking_value': '小云小云', | |||
| 'example': { | |||
| 'wav_count': | |||
| 1, | |||
| 'kws_type': | |||
| 'pcm', | |||
| 'kws_list': [{ | |||
| 'keyword': '小云小云', | |||
| 'offset': 0.69, | |||
| 'length': 1.67, | |||
| 'confidence': 0.996023 | |||
| }] | |||
| } | |||
| }, | |||
| 'test_run_with_pos_testsets': { | |||
| 'checking_item': ['recall'], | |||
| 'example': { | |||
| @@ -237,6 +254,12 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| self.check_result('test_run_with_wav_by_customized_keywords', | |||
| kws_result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_url(self): | |||
| kws_result = self.run_pipeline( | |||
| model_id=self.model_id, audio_in=URL_FILE) | |||
| self.check_result('test_run_with_url', kws_result) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_pos_testsets(self): | |||
| wav_file_path = download_and_untar( | |||