You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_automatic_speech_recognition.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import sys
  5. import tarfile
  6. import unittest
  7. from typing import Any, Dict, Union
  8. import numpy as np
  9. import requests
  10. import soundfile
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines import pipeline
  13. from modelscope.utils.constant import ColorCodes, Tasks
  14. from modelscope.utils.logger import get_logger
  15. from modelscope.utils.test_utils import download_and_untar, test_level
  16. logger = get_logger()
  17. WAV_FILE = 'data/test/audios/asr_example.wav'
  18. LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
  19. LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'
  20. AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz'
  21. AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz'
  22. TFRECORD_TESTSETS_FILE = 'tfrecord.tar.gz'
  23. TFRECORD_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/tfrecord.tar.gz'
  24. def un_tar_gz(fname, dirs):
  25. t = tarfile.open(fname)
  26. t.extractall(path=dirs)
  27. class AutomaticSpeechRecognitionTest(unittest.TestCase):
  28. action_info = {
  29. 'test_run_with_wav_pytorch': {
  30. 'checking_item': OutputKeys.TEXT,
  31. 'example': 'wav_example'
  32. },
  33. 'test_run_with_pcm_pytorch': {
  34. 'checking_item': OutputKeys.TEXT,
  35. 'example': 'wav_example'
  36. },
  37. 'test_run_with_wav_tf': {
  38. 'checking_item': OutputKeys.TEXT,
  39. 'example': 'wav_example'
  40. },
  41. 'test_run_with_pcm_tf': {
  42. 'checking_item': OutputKeys.TEXT,
  43. 'example': 'wav_example'
  44. },
  45. 'test_run_with_wav_dataset_pytorch': {
  46. 'checking_item': OutputKeys.TEXT,
  47. 'example': 'dataset_example'
  48. },
  49. 'test_run_with_wav_dataset_tf': {
  50. 'checking_item': OutputKeys.TEXT,
  51. 'example': 'dataset_example'
  52. },
  53. 'test_run_with_ark_dataset': {
  54. 'checking_item': OutputKeys.TEXT,
  55. 'example': 'dataset_example'
  56. },
  57. 'test_run_with_tfrecord_dataset': {
  58. 'checking_item': OutputKeys.TEXT,
  59. 'example': 'dataset_example'
  60. },
  61. 'dataset_example': {
  62. 'Wrd': 49532, # the number of words
  63. 'Snt': 5000, # the number of sentences
  64. 'Corr': 47276, # the number of correct words
  65. 'Ins': 49, # the number of insert words
  66. 'Del': 152, # the number of delete words
  67. 'Sub': 2207, # the number of substitution words
  68. 'wrong_words': 2408, # the number of wrong words
  69. 'wrong_sentences': 1598, # the number of wrong sentences
  70. 'Err': 4.86, # WER/CER
  71. 'S.Err': 31.96 # SER
  72. },
  73. 'wav_example': {
  74. 'text': '每一天都要快乐喔'
  75. }
  76. }
  77. def setUp(self) -> None:
  78. self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
  79. self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1'
  80. # this temporary workspace dir will store waveform files
  81. self.workspace = os.path.join(os.getcwd(), '.tmp')
  82. if not os.path.exists(self.workspace):
  83. os.mkdir(self.workspace)
  84. def tearDown(self) -> None:
  85. # remove workspace dir (.tmp)
  86. shutil.rmtree(self.workspace, ignore_errors=True)
  87. def run_pipeline(self,
  88. model_id: str,
  89. audio_in: Union[str, bytes],
  90. sr: int = 16000) -> Dict[str, Any]:
  91. inference_16k_pipline = pipeline(
  92. task=Tasks.auto_speech_recognition, model=model_id)
  93. rec_result = inference_16k_pipline(audio_in, audio_fs=sr)
  94. return rec_result
  95. def log_error(self, functions: str, result: Dict[str, Any]) -> None:
  96. logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
  97. + ColorCodes.END)
  98. logger.error(
  99. ColorCodes.MAGENTA + functions + ' correct result example:'
  100. + ColorCodes.YELLOW
  101. + str(self.action_info[self.action_info[functions]['example']])
  102. + ColorCodes.END)
  103. raise ValueError('asr result is mismatched')
  104. def check_result(self, functions: str, result: Dict[str, Any]) -> None:
  105. if result.__contains__(self.action_info[functions]['checking_item']):
  106. logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
  107. + ColorCodes.END)
  108. logger.info(
  109. ColorCodes.YELLOW
  110. + str(result[self.action_info[functions]['checking_item']])
  111. + ColorCodes.END)
  112. else:
  113. self.log_error(functions, result)
  114. def wav2bytes(self, wav_file):
  115. audio, fs = soundfile.read(wav_file)
  116. # float32 -> int16
  117. audio = np.asarray(audio)
  118. dtype = np.dtype('int16')
  119. i = np.iinfo(dtype)
  120. abs_max = 2**(i.bits - 1)
  121. offset = i.min + abs_max
  122. audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)
  123. # int16(PCM_16) -> byte
  124. audio = audio.tobytes()
  125. return audio, fs
  126. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  127. def test_run_with_wav_pytorch(self):
  128. '''run with single waveform file
  129. '''
  130. logger.info('Run ASR test with waveform file (pytorch)...')
  131. wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
  132. rec_result = self.run_pipeline(
  133. model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
  134. self.check_result('test_run_with_wav_pytorch', rec_result)
  135. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  136. def test_run_with_pcm_pytorch(self):
  137. '''run with wav data
  138. '''
  139. logger.info('Run ASR test with wav data (pytorch)...')
  140. audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
  141. rec_result = self.run_pipeline(
  142. model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
  143. self.check_result('test_run_with_pcm_pytorch', rec_result)
  144. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  145. def test_run_with_wav_tf(self):
  146. '''run with single waveform file
  147. '''
  148. logger.info('Run ASR test with waveform file (tensorflow)...')
  149. wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
  150. rec_result = self.run_pipeline(
  151. model_id=self.am_tf_model_id, audio_in=wav_file_path)
  152. self.check_result('test_run_with_wav_tf', rec_result)
  153. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  154. def test_run_with_pcm_tf(self):
  155. '''run with wav data
  156. '''
  157. logger.info('Run ASR test with wav data (tensorflow)...')
  158. audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
  159. rec_result = self.run_pipeline(
  160. model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
  161. self.check_result('test_run_with_pcm_tf', rec_result)
  162. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  163. def test_run_with_wav_dataset_pytorch(self):
  164. '''run with datasets, and audio format is waveform
  165. datasets directory:
  166. <dataset_path>
  167. wav
  168. test # testsets
  169. xx.wav
  170. ...
  171. dev # devsets
  172. yy.wav
  173. ...
  174. train # trainsets
  175. zz.wav
  176. ...
  177. transcript
  178. data.text # hypothesis text
  179. '''
  180. logger.info('Run ASR test with waveform dataset (pytorch)...')
  181. logger.info('Downloading waveform testsets file ...')
  182. dataset_path = download_and_untar(
  183. os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
  184. LITTLE_TESTSETS_URL, self.workspace)
  185. dataset_path = os.path.join(dataset_path, 'wav', 'test')
  186. rec_result = self.run_pipeline(
  187. model_id=self.am_pytorch_model_id, audio_in=dataset_path)
  188. self.check_result('test_run_with_wav_dataset_pytorch', rec_result)
  189. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  190. def test_run_with_wav_dataset_tf(self):
  191. '''run with datasets, and audio format is waveform
  192. datasets directory:
  193. <dataset_path>
  194. wav
  195. test # testsets
  196. xx.wav
  197. ...
  198. dev # devsets
  199. yy.wav
  200. ...
  201. train # trainsets
  202. zz.wav
  203. ...
  204. transcript
  205. data.text # hypothesis text
  206. '''
  207. logger.info('Run ASR test with waveform dataset (tensorflow)...')
  208. logger.info('Downloading waveform testsets file ...')
  209. dataset_path = download_and_untar(
  210. os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
  211. LITTLE_TESTSETS_URL, self.workspace)
  212. dataset_path = os.path.join(dataset_path, 'wav', 'test')
  213. rec_result = self.run_pipeline(
  214. model_id=self.am_tf_model_id, audio_in=dataset_path)
  215. self.check_result('test_run_with_wav_dataset_tf', rec_result)
  216. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  217. def test_run_with_ark_dataset(self):
  218. '''run with datasets, and audio format is kaldi_ark
  219. datasets directory:
  220. <dataset_path>
  221. test # testsets
  222. data.ark
  223. data.scp
  224. data.text
  225. dev # devsets
  226. data.ark
  227. data.scp
  228. data.text
  229. train # trainsets
  230. data.ark
  231. data.scp
  232. data.text
  233. '''
  234. logger.info('Run ASR test with ark dataset (pytorch)...')
  235. logger.info('Downloading ark testsets file ...')
  236. dataset_path = download_and_untar(
  237. os.path.join(self.workspace, AISHELL1_TESTSETS_FILE),
  238. AISHELL1_TESTSETS_URL, self.workspace)
  239. dataset_path = os.path.join(dataset_path, 'test')
  240. rec_result = self.run_pipeline(
  241. model_id=self.am_pytorch_model_id, audio_in=dataset_path)
  242. self.check_result('test_run_with_ark_dataset', rec_result)
  243. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  244. def test_run_with_tfrecord_dataset(self):
  245. '''run with datasets, and audio format is tfrecord
  246. datasets directory:
  247. <dataset_path>
  248. test # testsets
  249. data.records
  250. data.idx
  251. data.text
  252. '''
  253. logger.info('Run ASR test with tfrecord dataset (tensorflow)...')
  254. logger.info('Downloading tfrecord testsets file ...')
  255. dataset_path = download_and_untar(
  256. os.path.join(self.workspace, TFRECORD_TESTSETS_FILE),
  257. TFRECORD_TESTSETS_URL, self.workspace)
  258. dataset_path = os.path.join(dataset_path, 'test')
  259. rec_result = self.run_pipeline(
  260. model_id=self.am_tf_model_id, audio_in=dataset_path)
  261. self.check_result('test_run_with_tfrecord_dataset', rec_result)
  262. if __name__ == '__main__':
  263. unittest.main()