diff --git a/modelscope/pipelines/audio/asr_inference_pipeline.py b/modelscope/pipelines/audio/asr_inference_pipeline.py
index b321b770..282d1184 100644
--- a/modelscope/pipelines/audio/asr_inference_pipeline.py
+++ b/modelscope/pipelines/audio/asr_inference_pipeline.py
@@ -1,4 +1,3 @@
-import os
 from typing import Any, Dict, List, Sequence, Tuple, Union
 
 import yaml
@@ -9,6 +8,8 @@ from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import WavToScp
+from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
+                                                load_bytes_from_url)
 from modelscope.utils.constant import Frameworks, Tasks
 from modelscope.utils.logger import get_logger
 
@@ -41,12 +42,20 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
 
         self.recog_type = recog_type
         self.audio_format = audio_format
-        self.audio_in = audio_in
         self.audio_fs = audio_fs
 
+        if isinstance(audio_in, str):
+            # load pcm data from url if audio_in is url str
+            self.audio_in = load_bytes_from_url(audio_in)
+        elif isinstance(audio_in, bytes):
+            # load pcm data from wav data if audio_in is wave format
+            self.audio_in = extract_pcm_from_wav(audio_in)
+        else:
+            self.audio_in = audio_in
+
         if recog_type is None or audio_format is None:
             self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
-                audio_in=audio_in,
+                audio_in=self.audio_in,
                 recog_type=recog_type,
                 audio_format=audio_format)
 
diff --git a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py
index 61964345..c93e0102 100644
--- a/modelscope/utils/audio/audio_utils.py
+++ b/modelscope/utils/audio/audio_utils.py
@@ -1,4 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import struct
+from typing import Union
+from urllib.parse import urlparse
+
+from modelscope.fileio.file import HTTPStorage
+
 
 SEGMENT_LENGTH_TRAIN = 16000
 
@@ -29,3 +35,41 @@ def audio_norm(x):
     scalarx = 10**(-25 / 20) / rmsx
     x = x * scalarx
     return x
+
+
+def extract_pcm_from_wav(wav: bytes) -> bytes:
+    data = wav
+    if len(data) > 44:
+        frame_len = 44
+        file_len = len(data)
+        header_fields = {}
+        header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
+        header_fields['Format'] = str(data[8:12], 'UTF-8')
+        header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
+        if header_fields['ChunkID'] == 'RIFF' and header_fields[
+                'Format'] == 'WAVE' and header_fields['Subchunk1ID'] == 'fmt ':
+            header_fields['SubChunk1Size'] = struct.unpack('<I',
+                                                           data[16:20])[0]
+            # canonical header is 44 bytes when the fmt body is 16 bytes
+            frame_len = 44 + header_fields['SubChunk1Size'] - 16
+            data = data[frame_len:file_len]
+
+    return data
+
+
+def load_bytes_from_url(url: str) -> Union[bytes, str]:
+    result = urlparse(url)
+    if result.scheme is not None and len(result.scheme) > 0:
+        storage = HTTPStorage()
+        data = storage.read(url)
+        data = extract_pcm_from_wav(data)
+    else:
+        data = url
+
+    return data
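Note on the header arithmetic in `extract_pcm_from_wav`: a canonical WAV file is a RIFF container whose first 44 bytes are `RIFF` plus the chunk size, `WAVE`, a `fmt ` chunk (8-byte header and a 16-byte body for plain PCM), and the 8-byte `data` chunk header; the PCM payload starts immediately after. `SubChunk1Size` at bytes 16:20 is the size of the `fmt ` body, so the payload offset is `44 + SubChunk1Size - 16`. The sketch below is not part of the diff (the helper name `pcm_offset` is invented for illustration); it cross-checks that arithmetic against the stdlib `wave` module, under the same assumption the new helper makes, namely that the `data` chunk directly follows `fmt ` (files with extra `LIST` or `fact` chunks would be mis-sliced by both).

```python
import io
import struct
import wave


def pcm_offset(wav_bytes: bytes) -> int:
    """Byte offset of the PCM payload in a canonical RIFF/WAVE blob."""
    assert wav_bytes[0:4] == b'RIFF' and wav_bytes[8:12] == b'WAVE'
    subchunk1_size = struct.unpack('<I', wav_bytes[16:20])[0]
    # 12-byte RIFF header + 8-byte 'fmt ' chunk header + fmt body
    # + 8-byte 'data' chunk header; equals 44 when the body is 16 bytes
    return 12 + 8 + subchunk1_size + 8


with open('data/test/audios/asr_example.wav', 'rb') as f:
    blob = f.read()
with wave.open(io.BytesIO(blob)) as reader:
    frames = reader.readframes(reader.getnframes())
# holds for canonical files where 'data' directly follows 'fmt '
assert blob[pcm_offset(blob):] == frames
```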
diff --git a/tests/pipelines/test_automatic_speech_recognition.py b/tests/pipelines/test_automatic_speech_recognition.py
index a83f5031..7f4ce88e 100644
--- a/tests/pipelines/test_automatic_speech_recognition.py
+++ b/tests/pipelines/test_automatic_speech_recognition.py
@@ -16,16 +16,11 @@ from modelscope.utils.test_utils import download_and_untar, test_level
 logger = get_logger()
 
 WAV_FILE = 'data/test/audios/asr_example.wav'
+URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'
 
 LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
 LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'
 
-AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz'
-AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz'
-
-TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
-TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'
-
 
 class AutomaticSpeechRecognitionTest(unittest.TestCase):
     action_info = {
@@ -45,6 +40,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             'checking_item': OutputKeys.TEXT,
             'example': 'wav_example'
         },
+        'test_run_with_url_tf': {
+            'checking_item': OutputKeys.TEXT,
+            'example': 'wav_example'
+        },
         'test_run_with_wav_dataset_pytorch': {
             'checking_item': OutputKeys.TEXT,
             'example': 'dataset_example'
@@ -132,8 +131,8 @@
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav_pytorch(self):
-        '''run with single waveform file
-        '''
+        """run with single waveform file
+        """
 
         logger.info('Run ASR test with waveform file (pytorch)...')
 
@@ -145,8 +144,8 @@
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_pcm_pytorch(self):
-        '''run with wav data
-        '''
+        """run with wav data
+        """
 
         logger.info('Run ASR test with wav data (pytorch)...')
 
@@ -158,8 +157,8 @@
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav_tf(self):
-        '''run with single waveform file
-        '''
+        """run with single waveform file
+        """
 
         logger.info('Run ASR test with waveform file (tensorflow)...')
 
@@ -171,8 +170,8 @@
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_pcm_tf(self):
-        '''run with wav data
-        '''
+        """run with wav data
+        """
 
         logger.info('Run ASR test with wav data (tensorflow)...')
 
@@ -182,9 +181,20 @@
             model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
         self.check_result('test_run_with_pcm_tf', rec_result)
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_url_tf(self):
+        """run with single url file
+        """
+
+        logger.info('Run ASR test with url file (tensorflow)...')
+
+        rec_result = self.run_pipeline(
+            model_id=self.am_tf_model_id, audio_in=URL_FILE)
+        self.check_result('test_run_with_url_tf', rec_result)
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset_pytorch(self):
-        '''run with datasets, and audio format is waveform
+        """run with datasets, and audio format is waveform
 
         datasets directory:
         wav
@@ -199,7 +209,7 @@
             ...
         transcript
             data.text  # hypothesis text
-        '''
+        """
 
         logger.info('Run ASR test with waveform dataset (pytorch)...')
         logger.info('Downloading waveform testsets file ...')
@@ -215,7 +225,7 @@
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset_tf(self):
-        '''run with datasets, and audio format is waveform
+        """run with datasets, and audio format is waveform
 
         datasets directory:
         wav
@@ -230,7 +240,7 @@
             ...
         transcript
             data.text  # hypothesis text
-        '''
+        """
 
         logger.info('Run ASR test with waveform dataset (tensorflow)...')
         logger.info('Downloading waveform testsets file ...')
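With this change `audio_in` accepts three input forms: an http(s) URL (downloaded and header-stripped via `load_bytes_from_url`), raw WAV bytes (header-stripped via `extract_pcm_from_wav`), and anything else passed through untouched; a plain local path has no URL scheme, so it also falls through unchanged. A minimal usage sketch follows, assuming a placeholder model id (the tests read theirs from `self.am_tf_model_id`, whose value is not shown in this diff):

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'

# model id is a placeholder; substitute a real ASR model from the hub
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/some-asr-model')

# the pipeline now fetches the URL and strips the WAV header itself,
# so a bare URL string works the same as a local file path
rec_result = inference_pipeline(audio_in=URL_FILE)
print(rec_result[OutputKeys.TEXT])
```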