|
- import os
- import shutil
- import stat
- from pathlib import Path
- from typing import Any, Dict, List
-
- import yaml
-
- from modelscope.metainfo import Preprocessors
- from modelscope.models.base import Model
- from modelscope.utils.constant import Fields
- from .base import Preprocessor
- from .builder import PREPROCESSORS
-
- __all__ = ['WavToLists']
-
-
- @PREPROCESSORS.register_module(
- Fields.audio, module_name=Preprocessors.wav_to_lists)
- class WavToLists(Preprocessor):
- """generate audio lists file from wav
-
- Args:
- workspace (str): store temporarily kws intermedium and result
- """
-
- def __init__(self, workspace: str = None):
- # the workspace path
- if len(workspace) == 0:
- self._workspace = os.path.join(os.getcwd(), '.tmp')
- else:
- self._workspace = workspace
-
- if not os.path.exists(self._workspace):
- os.mkdir(self._workspace)
-
- def __call__(self,
- model: Model = None,
- kws_type: str = None,
- wav_path: List[str] = None) -> Dict[str, Any]:
- """Call functions to load model and wav.
-
- Args:
- model (Model): model should be provided
- kws_type (str): kws work type: wav, neg_testsets, pos_testsets, roc
- wav_path (List[str]): wav_path[0] is positive wav path, wav_path[1] is negative wav path
- Returns:
- Dict[str, Any]: the kws result
- """
-
- assert model is not None, 'preprocess kws model should be provided'
- assert kws_type in ['wav', 'pos_testsets', 'neg_testsets', 'roc'
- ], f'preprocess kws_type {kws_type} is invalid'
- assert wav_path[0] is not None or wav_path[
- 1] is not None, 'preprocess wav_path is invalid'
-
- self._model = model
- out = self.forward(self._model.forward(), kws_type, wav_path)
- return out
-
- def forward(self, model: Dict[str, Any], kws_type: str,
- wav_path: List[str]) -> Dict[str, Any]:
- assert len(kws_type) > 0, 'preprocess kws_type is empty'
- assert len(
- model['config_path']) > 0, 'preprocess model[config_path] is empty'
- assert os.path.exists(
- model['config_path']), 'model config.yaml is absent'
-
- inputs = model.copy()
-
- inputs['kws_set'] = kws_type
- inputs['workspace'] = self._workspace
- if wav_path[0] is not None:
- inputs['pos_wav_path'] = wav_path[0]
- if wav_path[1] is not None:
- inputs['neg_wav_path'] = wav_path[1]
-
- out = self._read_config(inputs)
- out = self._generate_wav_lists(out)
-
- return out
-
- def _read_config(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
- """read and parse config.yaml to get all model files
- """
-
- assert os.path.exists(
- inputs['config_path']), 'model config yaml file does not exist'
-
- config_file = open(inputs['config_path'])
- root = yaml.full_load(config_file)
- config_file.close()
-
- inputs['cfg_file'] = root['cfg_file']
- inputs['cfg_file_path'] = os.path.join(inputs['model_workspace'],
- root['cfg_file'])
- inputs['keyword_grammar'] = root['keyword_grammar']
- inputs['keyword_grammar_path'] = os.path.join(
- inputs['model_workspace'], root['keyword_grammar'])
- inputs['sample_rate'] = str(root['sample_rate'])
- inputs['kws_tool'] = root['kws_tool']
-
- if os.path.exists(
- os.path.join(inputs['workspace'], inputs['kws_tool'])):
- inputs['kws_tool_path'] = os.path.join(inputs['workspace'],
- inputs['kws_tool'])
- elif os.path.exists(os.path.join('/usr/bin', inputs['kws_tool'])):
- inputs['kws_tool_path'] = os.path.join('/usr/bin',
- inputs['kws_tool'])
- elif os.path.exists(os.path.join('/bin', inputs['kws_tool'])):
- inputs['kws_tool_path'] = os.path.join('/bin', inputs['kws_tool'])
-
- assert os.path.exists(inputs['kws_tool_path']), 'cannot find kwsbp'
- os.chmod(inputs['kws_tool_path'],
- stat.S_IXUSR + stat.S_IXGRP + stat.S_IXOTH)
-
- self._config_checking(inputs)
- return inputs
-
- def _generate_wav_lists(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
- """assemble wav lists
- """
-
- if inputs['kws_set'] == 'wav':
- inputs['pos_num_thread'] = 1
- wave_scp_content: str = inputs['pos_wav_path'] + '\n'
-
- with open(os.path.join(inputs['pos_data_path'], 'wave.list'),
- 'a') as f:
- f.write(wave_scp_content)
-
- inputs['pos_wav_count'] = 1
-
- if inputs['kws_set'] in ['pos_testsets', 'roc']:
- # find all positive wave
- wav_list = []
- wav_dir = inputs['pos_wav_path']
- wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)
-
- list_count: int = len(wav_list)
- inputs['pos_wav_count'] = list_count
-
- if list_count <= 128:
- inputs['pos_num_thread'] = list_count
- j: int = 0
- while j < list_count:
- wave_scp_content: str = wav_list[j] + '\n'
- wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
- j) + '.list'
- with open(wav_list_path, 'a') as f:
- f.write(wave_scp_content)
- j += 1
-
- else:
- inputs['pos_num_thread'] = 128
- j: int = 0
- k: int = 0
- while j < list_count:
- wave_scp_content: str = wav_list[j] + '\n'
- wav_list_path = inputs['pos_data_path'] + '/wave.' + str(
- k) + '.list'
- with open(wav_list_path, 'a') as f:
- f.write(wave_scp_content)
- j += 1
- k += 1
- if k >= 128:
- k = 0
-
- if inputs['kws_set'] in ['neg_testsets', 'roc']:
- # find all negative wave
- wav_list = []
- wav_dir = inputs['neg_wav_path']
- wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)
-
- list_count: int = len(wav_list)
- inputs['neg_wav_count'] = list_count
-
- if list_count <= 128:
- inputs['neg_num_thread'] = list_count
- j: int = 0
- while j < list_count:
- wave_scp_content: str = wav_list[j] + '\n'
- wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
- j) + '.list'
- with open(wav_list_path, 'a') as f:
- f.write(wave_scp_content)
- j += 1
-
- else:
- inputs['neg_num_thread'] = 128
- j: int = 0
- k: int = 0
- while j < list_count:
- wave_scp_content: str = wav_list[j] + '\n'
- wav_list_path = inputs['neg_data_path'] + '/wave.' + str(
- k) + '.list'
- with open(wav_list_path, 'a') as f:
- f.write(wave_scp_content)
- j += 1
- k += 1
- if k >= 128:
- k = 0
-
- return inputs
-
- def _recursion_dir_all_wave(self, wav_list,
- dir_path: str) -> Dict[str, Any]:
- dir_files = os.listdir(dir_path)
- for file in dir_files:
- file_path = os.path.join(dir_path, file)
- if os.path.isfile(file_path):
- if file_path.endswith('.wav') or file_path.endswith('.WAV'):
- wav_list.append(file_path)
- elif os.path.isdir(file_path):
- self._recursion_dir_all_wave(wav_list, file_path)
-
- return wav_list
-
- def _config_checking(self, inputs: Dict[str, Any]):
-
- if inputs['kws_set'] in ['wav', 'pos_testsets', 'roc']:
- inputs['pos_data_path'] = os.path.join(inputs['workspace'],
- 'pos_data')
- if not os.path.exists(inputs['pos_data_path']):
- os.mkdir(inputs['pos_data_path'])
- else:
- shutil.rmtree(inputs['pos_data_path'])
- os.mkdir(inputs['pos_data_path'])
-
- inputs['pos_dump_path'] = os.path.join(inputs['workspace'],
- 'pos_dump')
- if not os.path.exists(inputs['pos_dump_path']):
- os.mkdir(inputs['pos_dump_path'])
- else:
- shutil.rmtree(inputs['pos_dump_path'])
- os.mkdir(inputs['pos_dump_path'])
-
- if inputs['kws_set'] in ['neg_testsets', 'roc']:
- inputs['neg_data_path'] = os.path.join(inputs['workspace'],
- 'neg_data')
- if not os.path.exists(inputs['neg_data_path']):
- os.mkdir(inputs['neg_data_path'])
- else:
- shutil.rmtree(inputs['neg_data_path'])
- os.mkdir(inputs['neg_data_path'])
-
- inputs['neg_dump_path'] = os.path.join(inputs['workspace'],
- 'neg_dump')
- if not os.path.exists(inputs['neg_dump_path']):
- os.mkdir(inputs['neg_dump_path'])
- else:
- shutil.rmtree(inputs['neg_dump_path'])
- os.mkdir(inputs['neg_dump_path'])
|