Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10078262master
| @@ -8,6 +8,8 @@ from modelscope.models import Model | |||||
| from modelscope.pipelines.base import Pipeline | from modelscope.pipelines.base import Pipeline | ||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import WavToLists | from modelscope.preprocessors import WavToLists | ||||
| from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, | |||||
| load_bytes_from_url) | |||||
| from modelscope.utils.constant import Tasks | from modelscope.utils.constant import Tasks | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -40,6 +42,13 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| if self.preprocessor is None: | if self.preprocessor is None: | ||||
| self.preprocessor = WavToLists() | self.preprocessor = WavToLists() | ||||
| if isinstance(audio_in, str): | |||||
| # load pcm data from url if audio_in is url str | |||||
| audio_in = load_bytes_from_url(audio_in) | |||||
| elif isinstance(audio_in, bytes): | |||||
| # load pcm data from wav data if audio_in is wave format | |||||
| audio_in = extract_pcm_from_wav(audio_in) | |||||
| output = self.preprocessor.forward(self.model.forward(), audio_in) | output = self.preprocessor.forward(self.model.forward(), audio_in) | ||||
| output = self.forward(output) | output = self.forward(output) | ||||
| rst = self.postprocess(output) | rst = self.postprocess(output) | ||||
| @@ -42,23 +42,28 @@ def extract_pcm_from_wav(wav: bytes) -> bytes: | |||||
| if len(data) > 44: | if len(data) > 44: | ||||
| frame_len = 44 | frame_len = 44 | ||||
| file_len = len(data) | file_len = len(data) | ||||
| header_fields = {} | |||||
| header_fields['ChunkID'] = str(data[0:4], 'UTF-8') | |||||
| header_fields['Format'] = str(data[8:12], 'UTF-8') | |||||
| header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') | |||||
| if header_fields['ChunkID'] == 'RIFF' and header_fields[ | |||||
| 'Format'] == 'WAVE' and header_fields['Subchunk1ID'] == 'fmt ': | |||||
| header_fields['SubChunk1Size'] = struct.unpack('<I', | |||||
| data[16:20])[0] | |||||
| try: | |||||
| header_fields = {} | |||||
| header_fields['ChunkID'] = str(data[0:4], 'UTF-8') | |||||
| header_fields['Format'] = str(data[8:12], 'UTF-8') | |||||
| header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') | |||||
| if header_fields['ChunkID'] == 'RIFF' and header_fields[ | |||||
| 'Format'] == 'WAVE' and header_fields[ | |||||
| 'Subchunk1ID'] == 'fmt ': | |||||
| header_fields['SubChunk1Size'] = struct.unpack( | |||||
| '<I', data[16:20])[0] | |||||
| if header_fields['SubChunk1Size'] == 16: | |||||
| frame_len = 44 | |||||
| elif header_fields['SubChunk1Size'] == 18: | |||||
| frame_len = 46 | |||||
| else: | |||||
| return data | |||||
| if header_fields['SubChunk1Size'] == 16: | |||||
| frame_len = 44 | |||||
| elif header_fields['SubChunk1Size'] == 18: | |||||
| frame_len = 46 | |||||
| else: | |||||
| return data | |||||
| data = wav[frame_len:file_len] | |||||
| data = wav[frame_len:file_len] | |||||
| except Exception: | |||||
| # no treatment | |||||
| pass | |||||
| return data | return data | ||||
| @@ -18,6 +18,7 @@ logger = get_logger() | |||||
| POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' | POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' | ||||
| BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' | BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' | ||||
| URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/20200707_xiaoyun.wav' | |||||
| POS_TESTSETS_FILE = 'pos_testsets.tar.gz' | POS_TESTSETS_FILE = 'pos_testsets.tar.gz' | ||||
| POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' | POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' | ||||
| @@ -76,6 +77,22 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| }] | }] | ||||
| } | } | ||||
| }, | }, | ||||
| 'test_run_with_url': { | |||||
| 'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], | |||||
| 'checking_value': '小云小云', | |||||
| 'example': { | |||||
| 'wav_count': | |||||
| 1, | |||||
| 'kws_type': | |||||
| 'pcm', | |||||
| 'kws_list': [{ | |||||
| 'keyword': '小云小云', | |||||
| 'offset': 0.69, | |||||
| 'length': 1.67, | |||||
| 'confidence': 0.996023 | |||||
| }] | |||||
| } | |||||
| }, | |||||
| 'test_run_with_pos_testsets': { | 'test_run_with_pos_testsets': { | ||||
| 'checking_item': ['recall'], | 'checking_item': ['recall'], | ||||
| 'example': { | 'example': { | ||||
| @@ -237,6 +254,12 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| self.check_result('test_run_with_wav_by_customized_keywords', | self.check_result('test_run_with_wav_by_customized_keywords', | ||||
| kws_result) | kws_result) | ||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_url(self): | |||||
| kws_result = self.run_pipeline( | |||||
| model_id=self.model_id, audio_in=URL_FILE) | |||||
| self.check_result('test_run_with_url', kws_result) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_run_with_pos_testsets(self): | def test_run_with_pos_testsets(self): | ||||
| wav_file_path = download_and_untar( | wav_file_path = download_and_untar( | ||||