shichen.fsc, committed by yingda.chen · 3 years ago
commit 6fd5f671fa
3 changed files with 84 additions and 21 deletions
  1. modelscope/pipelines/audio/asr_inference_pipeline.py (+12, -3)
  2. modelscope/utils/audio/audio_utils.py (+44, -0)
  3. tests/pipelines/test_automatic_speech_recognition.py (+28, -18)

modelscope/pipelines/audio/asr_inference_pipeline.py (+12, -3)

@@ -1,4 +1,3 @@
-import os
 from typing import Any, Dict, List, Sequence, Tuple, Union
 
 import yaml
@@ -9,6 +8,8 @@
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import WavToScp
+from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
+                                                load_bytes_from_url)
 from modelscope.utils.constant import Frameworks, Tasks
 from modelscope.utils.logger import get_logger
@@ -41,12 +42,20 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
 
         self.recog_type = recog_type
         self.audio_format = audio_format
         self.audio_in = audio_in
         self.audio_fs = audio_fs
 
+        if isinstance(audio_in, str):
+            # load pcm data from url if audio_in is url str
+            self.audio_in = load_bytes_from_url(audio_in)
+        elif isinstance(audio_in, bytes):
+            # load pcm data from wav data if audio_in is wave format
+            self.audio_in = extract_pcm_from_wav(audio_in)
+        else:
+            self.audio_in = audio_in
 
         if recog_type is None or audio_format is None:
             self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
-                audio_in=audio_in,
+                audio_in=self.audio_in,
                 recog_type=recog_type,
                 audio_format=audio_format)
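
With this change, audio_in may be a URL string, a plain local path string, raw WAV bytes, or already-decoded samples. A minimal sketch of what each branch does, using the two new helpers (the URL and file paths below are placeholders, not part of the commit):

    from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
                                                    load_bytes_from_url)

    # str input: a real URL is fetched and its WAV header stripped to raw
    # PCM; a plain local path has no URL scheme and comes back unchanged.
    pcm = load_bytes_from_url('https://example.com/sample.wav')  # placeholder
    path = load_bytes_from_url('data/test/audios/asr_example.wav')

    # bytes input: treated as in-memory WAV data, header stripped to raw PCM.
    with open('data/test/audios/asr_example.wav', 'rb') as f:
        pcm_bytes = extract_pcm_from_wav(f.read())

    # Any other type (e.g. a numpy array of samples) is stored unchanged and
    # handled by asr_utils.type_checking as before.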



modelscope/utils/audio/audio_utils.py (+44, -0)

@@ -1,4 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import struct
+from typing import Union
+from urllib.parse import urlparse
+
+from modelscope.fileio.file import HTTPStorage
+
 SEGMENT_LENGTH_TRAIN = 16000


@@ -29,3 +35,41 @@ def audio_norm(x):
     scalarx = 10**(-25 / 20) / rmsx
     x = x * scalarx
     return x
+
+
+def extract_pcm_from_wav(wav: bytes) -> bytes:
+    data = wav
+    if len(data) > 44:
+        frame_len = 44
+        file_len = len(data)
+        header_fields = {}
+        header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
+        header_fields['Format'] = str(data[8:12], 'UTF-8')
+        header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
+        if header_fields['ChunkID'] == 'RIFF' and header_fields[
+                'Format'] == 'WAVE' and header_fields['Subchunk1ID'] == 'fmt ':
+            header_fields['SubChunk1Size'] = struct.unpack('<I',
+                                                           data[16:20])[0]
+
+            if header_fields['SubChunk1Size'] == 16:
+                frame_len = 44
+            elif header_fields['SubChunk1Size'] == 18:
+                frame_len = 46
+            else:
+                return data
+
+            data = wav[frame_len:file_len]
+
+    return data
+
+
+def load_bytes_from_url(url: str) -> Union[bytes, str]:
+    result = urlparse(url)
+    if result.scheme is not None and len(result.scheme) > 0:
+        storage = HTTPStorage()
+        data = storage.read(url)
+        data = extract_pcm_from_wav(data)
+    else:
+        data = url
+
+    return data
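
A canonical PCM WAV header is 44 bytes: a 12-byte RIFF descriptor ('RIFF' + size + 'WAVE'), a 24-byte fmt subchunk (8-byte chunk header + 16 bytes of fields when SubChunk1Size is 16), and the 8-byte data chunk header; a fmt subchunk carrying the 2-byte cbSize extension makes SubChunk1Size 18 and the header 46 bytes, which is exactly the 44/46 split encoded above. The parser does assume the data chunk starts right after fmt, so files with extra chunks (LIST, fact, ...) would need a real chunk walk. A sketch of such an alternative using only the standard library (extract_pcm_with_wave is a hypothetical name, not part of this commit):

    import io
    import wave

    def extract_pcm_with_wave(wav_bytes: bytes) -> bytes:
        # Let the stdlib walk the RIFF chunks instead of assuming a
        # fixed 44/46-byte header; returns only the raw PCM frames.
        with wave.open(io.BytesIO(wav_bytes), 'rb') as f:
            return f.readframes(f.getnframes())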

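load_bytes_from_url keys off urlparse: the scheme attribute is never None (it is '' when absent), so the length check is what distinguishes URLs from local paths, and a non-URL string is echoed back unchanged, hence the Union[bytes, str] return type. For example:

    from urllib.parse import urlparse

    urlparse('https://host/sample.wav').scheme           # 'https' -> fetched via HTTPStorage
    urlparse('data/test/audios/asr_example.wav').scheme  # ''      -> string returned as-is
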
tests/pipelines/test_automatic_speech_recognition.py (+28, -18)

@@ -16,16 +16,11 @@ from modelscope.utils.test_utils import download_and_untar, test_level
 logger = get_logger()
 
 WAV_FILE = 'data/test/audios/asr_example.wav'
+URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'
 
 LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
 LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'
 
 AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz'
 AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz'
 
 TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
 TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'


class AutomaticSpeechRecognitionTest(unittest.TestCase):
action_info = {
@@ -45,6 +40,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             'checking_item': OutputKeys.TEXT,
             'example': 'wav_example'
         },
+        'test_run_with_url_tf': {
+            'checking_item': OutputKeys.TEXT,
+            'example': 'wav_example'
+        },
         'test_run_with_wav_dataset_pytorch': {
             'checking_item': OutputKeys.TEXT,
             'example': 'dataset_example'
@@ -132,8 +131,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav_pytorch(self):
-        '''run with single waveform file
-        '''
+        """run with single waveform file
+        """
 
         logger.info('Run ASR test with waveform file (pytorch)...')

@@ -145,8 +144,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm_pytorch(self):
-        '''run with wav data
-        '''
+        """run with wav data
+        """
 
         logger.info('Run ASR test with wav data (pytorch)...')

@@ -158,8 +157,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_wav_tf(self):
-        '''run with single waveform file
-        '''
+        """run with single waveform file
+        """
 
         logger.info('Run ASR test with waveform file (tensorflow)...')

@@ -171,8 +170,8 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_pcm_tf(self):
-        '''run with wav data
-        '''
+        """run with wav data
+        """
 
         logger.info('Run ASR test with wav data (tensorflow)...')

@@ -182,9 +181,20 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
             model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
         self.check_result('test_run_with_pcm_tf', rec_result)
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_url_tf(self):
+        """run with single url file
+        """
+
+        logger.info('Run ASR test with url file (tensorflow)...')
+
+        rec_result = self.run_pipeline(
+            model_id=self.am_tf_model_id, audio_in=URL_FILE)
+        self.check_result('test_run_with_url_tf', rec_result)
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset_pytorch(self):
-        '''run with datasets, and audio format is waveform
+        """run with datasets, and audio format is waveform
         datasets directory:
             <dataset_path>
                 wav
@@ -199,7 +209,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
                         ...
                 transcript
                     data.text  # hypothesis text
-        '''
+        """
 
         logger.info('Run ASR test with waveform dataset (pytorch)...')
         logger.info('Downloading waveform testsets file ...')
@@ -215,7 +225,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_wav_dataset_tf(self):
-        '''run with datasets, and audio format is waveform
+        """run with datasets, and audio format is waveform
         datasets directory:
             <dataset_path>
                 wav
@@ -230,7 +240,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase):
                         ...
                 transcript
                     data.text  # hypothesis text
-        '''
+        """
 
         logger.info('Run ASR test with waveform dataset (tensorflow)...')
         logger.info('Downloading waveform testsets file ...')
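
End to end, the new test exercises the URL path roughly as follows (a sketch, not the test's literal code: '<tf-asr-model-id>' is a placeholder for the suite's self.am_tf_model_id, and the call mirrors how run_pipeline is used elsewhere in this file):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'

    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition, model='<tf-asr-model-id>')
    rec_result = inference_pipeline(audio_in=URL_FILE)
    print(rec_result[OutputKeys.TEXT])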

