shichen.fsc · yingda.chen · 3 years ago
commit 6fd5f671fa
3 changed files with 84 additions and 21 deletions
  1. modelscope/pipelines/audio/asr_inference_pipeline.py  (+12, -3)
  2. modelscope/utils/audio/audio_utils.py  (+44, -0)
  3. tests/pipelines/test_automatic_speech_recognition.py  (+28, -18)

modelscope/pipelines/audio/asr_inference_pipeline.py  (+12, -3)

@@ -1,4 +1,3 @@
-import os
 from typing import Any, Dict, List, Sequence, Tuple, Union
 
 import yaml
@@ -9,6 +8,8 @@ from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import WavToScp
+from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
+                                                load_bytes_from_url)
 from modelscope.utils.constant import Frameworks, Tasks
 from modelscope.utils.logger import get_logger
@@ -41,12 +42,20 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
         self.recog_type = recog_type
         self.audio_format = audio_format
-        self.audio_in = audio_in
         self.audio_fs = audio_fs
 
+        if isinstance(audio_in, str):
+            # load pcm data from url if audio_in is url str
+            self.audio_in = load_bytes_from_url(audio_in)
+        elif isinstance(audio_in, bytes):
+            # load pcm data from wav data if audio_in is wave format
+            self.audio_in = extract_pcm_from_wav(audio_in)
+        else:
+            self.audio_in = audio_in
+
         if recog_type is None or audio_format is None:
             self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
-                audio_in=audio_in,
+                audio_in=self.audio_in,
                 recog_type=recog_type,
                 audio_format=audio_format)
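
With this hunk, audio_in may be a URL string or raw wav bytes; both are normalized to pcm before asr_utils.type_checking runs. A minimal usage sketch (not part of the commit; it assumes the standard modelscope pipeline() factory and the documented audio_in call style, and the model id is a placeholder):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Placeholder model id for illustration only; substitute a real ASR model.
    asr = pipeline(task=Tasks.auto_speech_recognition, model='<asr-model-id>')

    # A URL string is fetched and header-stripped via load_bytes_from_url();
    # raw wav bytes would go through extract_pcm_from_wav() instead.
    rec_result = asr(
        audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav')
    print(rec_result)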




modelscope/utils/audio/audio_utils.py  (+44, -0)

@@ -1,4 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import struct
+from typing import Union
+from urllib.parse import urlparse
+
+from modelscope.fileio.file import HTTPStorage
+
 SEGMENT_LENGTH_TRAIN = 16000
@@ -29,3 +35,41 @@ def audio_norm(x):
     scalarx = 10**(-25 / 20) / rmsx
     x = x * scalarx
     return x
+
+
+def extract_pcm_from_wav(wav: bytes) -> bytes:
+    data = wav
+    if len(data) > 44:
+        frame_len = 44
+        file_len = len(data)
+        header_fields = {}
+        header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
+        header_fields['Format'] = str(data[8:12], 'UTF-8')
+        header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
+        if header_fields['ChunkID'] == 'RIFF' and header_fields[
+                'Format'] == 'WAVE' and header_fields['Subchunk1ID'] == 'fmt ':
+            header_fields['SubChunk1Size'] = struct.unpack('<I',
+                                                           data[16:20])[0]
+
+            if header_fields['SubChunk1Size'] == 16:
+                frame_len = 44
+            elif header_fields['SubChunk1Size'] == 18:
+                frame_len = 46
+            else:
+                return data
+
+            data = wav[frame_len:file_len]
+
+    return data
+
+
+def load_bytes_from_url(url: str) -> Union[bytes, str]:
+    result = urlparse(url)
+    if result.scheme is not None and len(result.scheme) > 0:
+        storage = HTTPStorage()
+        data = storage.read(url)
+        data = extract_pcm_from_wav(data)
+    else:
+        data = url
+
+    return data
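
For reference: a canonical PCM wav header is 44 bytes (12-byte RIFF descriptor, 24-byte fmt chunk with SubChunk1Size 16, 8-byte data chunk header), and encoders that write an 18-byte fmt chunk (an extra cbSize field) produce a 46-byte header; those are exactly the two offsets handled above. A quick self-contained check of that math (a sketch, not part of the commit; assumes extract_pcm_from_wav is importable as added here):

    import struct

    from modelscope.utils.audio.audio_utils import extract_pcm_from_wav


    def make_wav(pcm: bytes, sample_rate: int = 16000) -> bytes:
        # Minimal mono 16-bit PCM wav: RIFF descriptor (12 bytes), fmt chunk
        # with SubChunk1Size == 16 (24 bytes), data chunk header (8 bytes).
        byte_rate = sample_rate * 2
        return (b'RIFF' + struct.pack('<I', 36 + len(pcm)) + b'WAVE'
                + b'fmt ' + struct.pack('<IHHIIHH', 16, 1, 1, sample_rate,
                                        byte_rate, 2, 16)
                + b'data' + struct.pack('<I', len(pcm)) + pcm)


    pcm = b'\x01\x02' * 100  # fake 16-bit samples
    assert extract_pcm_from_wav(make_wav(pcm)) == pcm

Note the fixed offsets assume the data chunk immediately follows fmt; wav files carrying extra chunks (e.g. LIST metadata) keep those chunk headers in the returned bytes.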

tests/pipelines/test_automatic_speech_recognition.py  (+28, -18)

@@ -16,16 +16,11 @@ from modelscope.utils.test_utils import download_and_untar, test_level
 logger = get_logger()
 
 WAV_FILE = 'data/test/audios/asr_example.wav'
+URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'
 
 LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
 LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'
 
-AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz'
-AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz'
-
-TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
-TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'
-
 
 class AutomaticSpeechRecognitionTest(unittest.TestCase):
     action_info = {
@@ -45,6 +40,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             'checking_item': OutputKeys.TEXT,
             'example': 'wav_example'
         },
+        'test_run_with_url_tf': {
+            'checking_item': OutputKeys.TEXT,
+            'example': 'wav_example'
+        },
         'test_run_with_wav_dataset_pytorch': {
             'checking_item': OutputKeys.TEXT,
             'example': 'dataset_example'
@@ -132,8 +131,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav_pytorch(self):
-        '''run with single waveform file
-        '''
+        """run with single waveform file
+        """
 
         logger.info('Run ASR test with waveform file (pytorch)...')
 
@@ -145,8 +144,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_pcm_pytorch(self):
-        '''run with wav data
-        '''
+        """run with wav data
+        """
 
         logger.info('Run ASR test with wav data (pytorch)...')
 
@@ -158,8 +157,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav_tf(self):
-        '''run with single waveform file
-        '''
+        """run with single waveform file
+        """
 
         logger.info('Run ASR test with waveform file (tensorflow)...')
 
@@ -171,8 +170,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_pcm_tf(self):
-        '''run with wav data
-        '''
+        """run with wav data
+        """
 
         logger.info('Run ASR test with wav data (tensorflow)...')
 
@@ -182,9 +181,20 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
         self.check_result('test_run_with_pcm_tf', rec_result)
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_url_tf(self):
+        """run with single url file
+        """
+
+        logger.info('Run ASR test with url file (tensorflow)...')
+
+        rec_result = self.run_pipeline(
+            model_id=self.am_tf_model_id, audio_in=URL_FILE)
+        self.check_result('test_run_with_url_tf', rec_result)
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset_pytorch(self):
-        '''run with datasets, and audio format is waveform
+        """run with datasets, and audio format is waveform
         datasets directory:
           <dataset_path>
             wav
@@ -199,7 +209,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
               ...
             transcript
              data.text  # hypothesis text
-        '''
+        """
 
         logger.info('Run ASR test with waveform dataset (pytorch)...')
         logger.info('Downloading waveform testsets file ...')
@@ -215,7 +225,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset_tf(self):
-        '''run with datasets, and audio format is waveform
+        """run with datasets, and audio format is waveform
         datasets directory:
          <dataset_path>
            wav
@@ -230,7 +240,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
              ...
            transcript
              data.text  # hypothesis text
-        '''
+        """
 
         logger.info('Run ASR test with waveform dataset (tensorflow)...')
         logger.info('Downloading waveform testsets file ...')
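
To exercise just the new URL-input case locally, the standard unittest entry point should work (assuming tests are run from the repository root and the configured test level admits level-0 tests):

    python -m unittest tests.pipelines.test_automatic_speech_recognition.AutomaticSpeechRecognitionTest.test_run_with_url_tf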

