Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10275823master
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d | |||||
| size 2327764 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150 | |||||
| size 68524 | |||||
| @@ -285,6 +285,7 @@ class Trainers(object): | |||||
| # audio trainers | # audio trainers | ||||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | ||||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||||
| class Preprocessors(object): | class Preprocessors(object): | ||||
| @@ -1,15 +1,14 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| from typing import Dict | |||||
| import torch | |||||
| from typing import Dict, Optional | |||||
| from modelscope.metainfo import Models | from modelscope.metainfo import Models | ||||
| from modelscope.models import TorchModel | from modelscope.models import TorchModel | ||||
| from modelscope.models.base import Tensor | from modelscope.models.base import Tensor | ||||
| from modelscope.models.builder import MODELS | from modelscope.models.builder import MODELS | ||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.audio.audio_utils import update_conf | |||||
| from modelscope.utils.constant import Tasks | |||||
| from .fsmn_sele_v2 import FSMNSeleNetV2 | from .fsmn_sele_v2 import FSMNSeleNetV2 | ||||
| @@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||||
| MODEL_TXT = 'model.txt' | MODEL_TXT = 'model.txt' | ||||
| SC_CONFIG = 'sound_connect.conf' | SC_CONFIG = 'sound_connect.conf' | ||||
| SC_CONF_ITEM_KWS_MODEL = '${kws_model}' | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| def __init__(self, | |||||
| model_dir: str, | |||||
| training: Optional[bool] = False, | |||||
| *args, | |||||
| **kwargs): | |||||
| """initialize the dfsmn model from the `model_dir` path. | """initialize the dfsmn model from the `model_dir` path. | ||||
| Args: | Args: | ||||
| model_dir (str): the model path. | model_dir (str): the model path. | ||||
| """ | """ | ||||
| super().__init__(model_dir, *args, **kwargs) | super().__init__(model_dir, *args, **kwargs) | ||||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||||
| model_bin_file = os.path.join(model_dir, | |||||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||||
| self._model = None | |||||
| if os.path.exists(model_bin_file): | |||||
| kwargs.pop('device') | |||||
| self._model = FSMNSeleNetV2(*args, **kwargs) | |||||
| checkpoint = torch.load(model_bin_file) | |||||
| self._model.load_state_dict(checkpoint, strict=False) | |||||
| self._sc = None | |||||
| if os.path.exists(model_txt_file): | |||||
| with open(sc_config_file) as f: | |||||
| lines = f.readlines() | |||||
| with open(sc_config_file, 'w') as f: | |||||
| for line in lines: | |||||
| if self.SC_CONF_ITEM_KWS_MODEL in line: | |||||
| line = line.replace(self.SC_CONF_ITEM_KWS_MODEL, | |||||
| model_txt_file) | |||||
| f.write(line) | |||||
| import py_sound_connect | |||||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||||
| self.size_in = self._sc.bytesPerBlockIn() | |||||
| self.size_out = self._sc.bytesPerBlockOut() | |||||
| if self._model is None and self._sc is None: | |||||
| raise Exception( | |||||
| f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.' | |||||
| ) | |||||
| if training: | |||||
| self.model = FSMNSeleNetV2(*args, **kwargs) | |||||
| else: | |||||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||||
| self._sc = None | |||||
| if os.path.exists(model_txt_file): | |||||
| conf_dict = dict(mode=56542, kws_model=model_txt_file) | |||||
| update_conf(sc_config_file, sc_config_file, conf_dict) | |||||
| import py_sound_connect | |||||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||||
| self.size_in = self._sc.bytesPerBlockIn() | |||||
| self.size_out = self._sc.bytesPerBlockOut() | |||||
| else: | |||||
| raise Exception( | |||||
| f'Invalid model directory! Failed to load model file: {model_txt_file}.' | |||||
| ) | |||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||||
| ... | |||||
| return self.model.forward(input) | |||||
| def forward_decode(self, data: bytes): | def forward_decode(self, data: bytes): | ||||
| result = {'pcm': self._sc.process(data, self.size_out)} | result = {'pcm': self._sc.process(data, self.size_out)} | ||||
| @@ -0,0 +1,21 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .kws_farfield_dataset import KWSDataset, KWSDataLoader | |||||
| else: | |||||
| _import_structure = { | |||||
| 'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,280 @@ | |||||
| """ | |||||
| Used to prepare simulated data. | |||||
| """ | |||||
| import math | |||||
| import os.path | |||||
| import queue | |||||
| import threading | |||||
| import time | |||||
| import numpy as np | |||||
| import torch | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| BLOCK_DEC = 2 | |||||
| BLOCK_CAT = 3 | |||||
| FBANK_SIZE = 40 | |||||
| LABEL_SIZE = 1 | |||||
| LABEL_GAIN = 100.0 | |||||
| class KWSDataset: | |||||
| """ | |||||
| dataset for keyword spotting and vad | |||||
| conf_basetrain: basetrain configure file path | |||||
| conf_finetune: finetune configure file path, null allowed | |||||
| numworkers: no. of workers | |||||
| basetrainratio: basetrain workers ratio | |||||
| numclasses: no. of nn output classes, 2 classes to generate vad label | |||||
| blockdec: block decimation | |||||
| blockcat: block concatenation | |||||
| """ | |||||
| def __init__(self, | |||||
| conf_basetrain, | |||||
| conf_finetune, | |||||
| numworkers, | |||||
| basetrainratio, | |||||
| numclasses, | |||||
| blockdec=BLOCK_CAT, | |||||
| blockcat=BLOCK_CAT): | |||||
| super().__init__() | |||||
| self.numclasses = numclasses | |||||
| self.blockdec = blockdec | |||||
| self.blockcat = blockcat | |||||
| self.sims_base = [] | |||||
| self.sims_senior = [] | |||||
| self.setup_sims(conf_basetrain, conf_finetune, numworkers, | |||||
| basetrainratio) | |||||
| def release(self): | |||||
| for sim in self.sims_base: | |||||
| del sim | |||||
| for sim in self.sims_senior: | |||||
| del sim | |||||
| del self.base_conf | |||||
| del self.senior_conf | |||||
| logger.info('KWSDataset: Released.') | |||||
| def setup_sims(self, conf_basetrain, conf_finetune, numworkers, | |||||
| basetrainratio): | |||||
| if not os.path.exists(conf_basetrain): | |||||
| raise ValueError(f'{conf_basetrain} does not exist!') | |||||
| if not os.path.exists(conf_finetune): | |||||
| raise ValueError(f'{conf_finetune} does not exist!') | |||||
| import py_sound_connect | |||||
| logger.info('KWSDataset init SoundConnect...') | |||||
| num_base = math.ceil(numworkers * basetrainratio) | |||||
| num_senior = numworkers - num_base | |||||
| # hold by fields to avoid python releasing conf object | |||||
| self.base_conf = py_sound_connect.ConfigFile(conf_basetrain) | |||||
| self.senior_conf = py_sound_connect.ConfigFile(conf_finetune) | |||||
| for i in range(num_base): | |||||
| fs = py_sound_connect.FeatSimuKWS(self.base_conf.params) | |||||
| self.sims_base.append(fs) | |||||
| for i in range(num_senior): | |||||
| self.sims_senior.append( | |||||
| py_sound_connect.FeatSimuKWS(self.senior_conf.params)) | |||||
| logger.info('KWSDataset init SoundConnect finished.') | |||||
| def getBatch(self, id): | |||||
| """ | |||||
| Generate a data batch | |||||
| Args: | |||||
| id: worker id | |||||
| Return: time x channel x feature, label | |||||
| """ | |||||
| fs = self.get_sim(id) | |||||
| fs.processBatch() | |||||
| # get multi-channel feature vector size | |||||
| featsize = fs.featSize() | |||||
| # get label vector size | |||||
| labelsize = fs.labelSize() | |||||
| # get minibatch size (time dimension) | |||||
| # batchsize = fs.featBatchSize() | |||||
| # no. of fe output channels | |||||
| numchs = featsize // FBANK_SIZE | |||||
| # get raw data | |||||
| fs_feat = fs.feat() | |||||
| data = np.frombuffer(fs_feat, dtype='float32') | |||||
| data = data.reshape((-1, featsize + labelsize)) | |||||
| # convert float label to int | |||||
| label = data[:, FBANK_SIZE * numchs:] | |||||
| if self.numclasses == 2: | |||||
| # generate vad label | |||||
| label[label > 0.0] = 1.0 | |||||
| else: | |||||
| # generate kws label | |||||
| label = np.round(label * LABEL_GAIN) | |||||
| label[label > self.numclasses - 1] = 0.0 | |||||
| # decimated size | |||||
| size1 = int(np.ceil( | |||||
| label.shape[0] / self.blockdec)) - self.blockcat + 1 | |||||
| # label decimation | |||||
| label1 = np.zeros((size1, LABEL_SIZE), dtype='float32') | |||||
| for tau in range(size1): | |||||
| label1[tau, :] = label[(tau + self.blockcat // 2) | |||||
| * self.blockdec, :] | |||||
| # feature decimation and concatenation | |||||
| # time x channel x feature | |||||
| featall = np.zeros((size1, numchs, FBANK_SIZE * self.blockcat), | |||||
| dtype='float32') | |||||
| for n in range(numchs): | |||||
| feat = data[:, FBANK_SIZE * n:FBANK_SIZE * (n + 1)] | |||||
| for tau in range(size1): | |||||
| for i in range(self.blockcat): | |||||
| featall[tau, n, FBANK_SIZE * i:FBANK_SIZE * (i + 1)] = \ | |||||
| feat[(tau + i) * self.blockdec, :] | |||||
| return torch.from_numpy(featall), torch.from_numpy(label1).long() | |||||
| def get_sim(self, id): | |||||
| num_base = len(self.sims_base) | |||||
| if id < num_base: | |||||
| fs = self.sims_base[id] | |||||
| else: | |||||
| fs = self.sims_senior[id - num_base] | |||||
| return fs | |||||
| class Worker(threading.Thread): | |||||
| """ | |||||
| id: worker id | |||||
| dataset: the dataset | |||||
| pool: queue as the global data buffer | |||||
| """ | |||||
| def __init__(self, id, dataset, pool): | |||||
| threading.Thread.__init__(self) | |||||
| self.id = id | |||||
| self.dataset = dataset | |||||
| self.pool = pool | |||||
| self.isrun = True | |||||
| self.nn = 0 | |||||
| def run(self): | |||||
| while self.isrun: | |||||
| self.nn += 1 | |||||
| logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:1') | |||||
| # get simulated minibatch | |||||
| if self.isrun: | |||||
| data = self.dataset.getBatch(self.id) | |||||
| logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:2') | |||||
| # put data into buffer | |||||
| if self.isrun: | |||||
| self.pool.put(data) | |||||
| logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:3') | |||||
| logger.info('KWSDataLoader: Worker {:02d} stopped.'.format(self.id)) | |||||
| def stopWorker(self): | |||||
| """ | |||||
| stop the worker thread | |||||
| """ | |||||
| self.isrun = False | |||||
| class KWSDataLoader: | |||||
| """ | |||||
| dataset: the dataset reference | |||||
| batchsize: data batch size | |||||
| numworkers: no. of workers | |||||
| prefetch: prefetch factor | |||||
| """ | |||||
| def __init__(self, dataset, batchsize, numworkers, prefetch=2): | |||||
| self.dataset = dataset | |||||
| self.batchsize = batchsize | |||||
| self.datamap = {} | |||||
| self.isrun = True | |||||
| # data queue | |||||
| self.pool = queue.Queue(batchsize * prefetch) | |||||
| # initialize workers | |||||
| self.workerlist = [] | |||||
| for id in range(numworkers): | |||||
| w = Worker(id, dataset, self.pool) | |||||
| self.workerlist.append(w) | |||||
| def __iter__(self): | |||||
| return self | |||||
| def __next__(self): | |||||
| while self.isrun: | |||||
| # get data from common data pool | |||||
| data = self.pool.get() | |||||
| self.pool.task_done() | |||||
| # group minibatches with the same shape | |||||
| key = str(data[0].shape) | |||||
| batchl = self.datamap.get(key) | |||||
| if batchl is None: | |||||
| batchl = [] | |||||
| self.datamap.update({key: batchl}) | |||||
| batchl.append(data) | |||||
| # a full data batch collected | |||||
| if len(batchl) >= self.batchsize: | |||||
| featbatch = [] | |||||
| labelbatch = [] | |||||
| for feat, label in batchl: | |||||
| featbatch.append(feat) | |||||
| labelbatch.append(label) | |||||
| batchl.clear() | |||||
| feattensor = torch.stack(featbatch, dim=0) | |||||
| labeltensor = torch.stack(labelbatch, dim=0) | |||||
| if feattensor.shape[-2] == 1: | |||||
| logger.debug('KWSDataLoader: Basetrain batch.') | |||||
| else: | |||||
| logger.debug('KWSDataLoader: Finetune batch.') | |||||
| return feattensor, labeltensor | |||||
| return None, None | |||||
| def start(self): | |||||
| """ | |||||
| start multi-thread data loader | |||||
| """ | |||||
| for w in self.workerlist: | |||||
| w.start() | |||||
| def stop(self): | |||||
| """ | |||||
| stop data loader | |||||
| """ | |||||
| logger.info('KWSDataLoader: Stopping...') | |||||
| self.isrun = False | |||||
| for w in self.workerlist: | |||||
| w.stopWorker() | |||||
| while not self.pool.empty(): | |||||
| self.pool.get(block=True, timeout=0.001) | |||||
| # wait workers terminated | |||||
| for w in self.workerlist: | |||||
| while not self.pool.empty(): | |||||
| self.pool.get(block=True, timeout=0.001) | |||||
| w.join() | |||||
| logger.info('KWSDataLoader: All worker stopped.') | |||||
| @@ -0,0 +1,279 @@ | |||||
| import datetime | |||||
| import math | |||||
| import os | |||||
| from typing import Callable, Dict, Optional | |||||
| import numpy as np | |||||
| import torch | |||||
| from torch import nn as nn | |||||
| from torch import optim as optim | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.models import Model, TorchModel | |||||
| from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset | |||||
| from modelscope.trainers.base import BaseTrainer | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.utils.audio.audio_utils import update_conf | |||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||||
| from modelscope.utils.data_utils import to_device | |||||
| from modelscope.utils.device import create_device | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, | |||||
| init_dist, is_master) | |||||
| logger = get_logger() | |||||
| BASETRAIN_CONF_EASY = 'basetrain_easy' | |||||
| BASETRAIN_CONF_NORMAL = 'basetrain_normal' | |||||
| BASETRAIN_CONF_HARD = 'basetrain_hard' | |||||
| FINETUNE_CONF_EASY = 'finetune_easy' | |||||
| FINETUNE_CONF_NORMAL = 'finetune_normal' | |||||
| FINETUNE_CONF_HARD = 'finetune_hard' | |||||
| EASY_RATIO = 0.1 | |||||
| NORMAL_RATIO = 0.6 | |||||
| HARD_RATIO = 0.3 | |||||
| BASETRAIN_RATIO = 0.5 | |||||
| @TRAINERS.register_module(module_name=Trainers.speech_dfsmn_kws_char_farfield) | |||||
| class KWSFarfieldTrainer(BaseTrainer): | |||||
| DEFAULT_WORK_DIR = './work_dir' | |||||
| conf_keys = (BASETRAIN_CONF_EASY, FINETUNE_CONF_EASY, | |||||
| BASETRAIN_CONF_NORMAL, FINETUNE_CONF_NORMAL, | |||||
| BASETRAIN_CONF_HARD, FINETUNE_CONF_HARD) | |||||
| def __init__(self, | |||||
| model: str, | |||||
| work_dir: str, | |||||
| cfg_file: Optional[str] = None, | |||||
| arg_parse_fn: Optional[Callable] = None, | |||||
| model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||||
| custom_conf: Optional[dict] = None, | |||||
| **kwargs): | |||||
| if isinstance(model, str): | |||||
| if os.path.exists(model): | |||||
| self.model_dir = model if os.path.isdir( | |||||
| model) else os.path.dirname(model) | |||||
| else: | |||||
| self.model_dir = snapshot_download( | |||||
| model, revision=model_revision) | |||||
| if cfg_file is None: | |||||
| cfg_file = os.path.join(self.model_dir, | |||||
| ModelFile.CONFIGURATION) | |||||
| else: | |||||
| assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' | |||||
| self.model_dir = os.path.dirname(cfg_file) | |||||
| super().__init__(cfg_file, arg_parse_fn) | |||||
| self.model = self.build_model() | |||||
| self.work_dir = work_dir | |||||
| # the number of model output dimension | |||||
| # should update config outside the trainer, if user need more wake word | |||||
| self._num_classes = self.cfg.model.num_syn | |||||
| if kwargs.get('launcher', None) is not None: | |||||
| init_dist(kwargs['launcher']) | |||||
| _, world_size = get_dist_info() | |||||
| self._dist = world_size > 1 | |||||
| device_name = kwargs.get('device', 'gpu') | |||||
| if self._dist: | |||||
| local_rank = get_local_rank() | |||||
| device_name = f'cuda:{local_rank}' | |||||
| self.device = create_device(device_name) | |||||
| # model placement | |||||
| if self.device.type == 'cuda': | |||||
| self.model.to(self.device) | |||||
| if 'max_epochs' not in kwargs: | |||||
| assert hasattr( | |||||
| self.cfg.train, 'max_epochs' | |||||
| ), 'max_epochs is missing from the configuration file' | |||||
| self._max_epochs = self.cfg.train.max_epochs | |||||
| else: | |||||
| self._max_epochs = kwargs['max_epochs'] | |||||
| self._train_iters = kwargs.get('train_iters_per_epoch', None) | |||||
| self._val_iters = kwargs.get('val_iters_per_epoch', None) | |||||
| if self._train_iters is None: | |||||
| self._train_iters = self.cfg.train.train_iters_per_epoch | |||||
| if self._val_iters is None: | |||||
| self._val_iters = self.cfg.evaluation.val_iters_per_epoch | |||||
| dataloader_config = self.cfg.train.dataloader | |||||
| self._threads = kwargs.get('workers', None) | |||||
| if self._threads is None: | |||||
| self._threads = dataloader_config.workers_per_gpu | |||||
| self._single_rate = BASETRAIN_RATIO | |||||
| if 'single_rate' in kwargs: | |||||
| self._single_rate = kwargs['single_rate'] | |||||
| self._batch_size = dataloader_config.batch_size_per_gpu | |||||
| if 'model_bin' in kwargs: | |||||
| model_bin_file = os.path.join(self.model_dir, kwargs['model_bin']) | |||||
| checkpoint = torch.load(model_bin_file) | |||||
| self.model.load_state_dict(checkpoint) | |||||
| # build corresponding optimizer and loss function | |||||
| lr = self.cfg.train.optimizer.lr | |||||
| self.optimizer = optim.Adam(self.model.parameters(), lr) | |||||
| self.loss_fn = nn.CrossEntropyLoss() | |||||
| self.data_val = None | |||||
| self.json_log_path = os.path.join(self.work_dir, | |||||
| '{}.log.json'.format(self.timestamp)) | |||||
| self.conf_files = [] | |||||
| for conf_key in self.conf_keys: | |||||
| template_file = os.path.join(self.model_dir, conf_key) | |||||
| conf_file = os.path.join(self.model_dir, f'{conf_key}.conf') | |||||
| update_conf(template_file, conf_file, custom_conf[conf_key]) | |||||
| self.conf_files.append(conf_file) | |||||
| self._current_epoch = 0 | |||||
| self.stages = (math.floor(self._max_epochs * EASY_RATIO), | |||||
| math.floor(self._max_epochs * NORMAL_RATIO), | |||||
| math.floor(self._max_epochs * HARD_RATIO)) | |||||
| def build_model(self) -> nn.Module: | |||||
| """ Instantiate a pytorch model and return. | |||||
| By default, we will create a model using config from configuration file. You can | |||||
| override this method in a subclass. | |||||
| """ | |||||
| model = Model.from_pretrained( | |||||
| self.model_dir, cfg_dict=self.cfg, training=True) | |||||
| if isinstance(model, TorchModel) and hasattr(model, 'model'): | |||||
| return model.model | |||||
| elif isinstance(model, nn.Module): | |||||
| return model | |||||
| def train(self, *args, **kwargs): | |||||
| if not self.data_val: | |||||
| self.gen_val() | |||||
| logger.info('Start training...') | |||||
| totaltime = datetime.datetime.now() | |||||
| for stage, num_epoch in enumerate(self.stages): | |||||
| self.run_stage(stage, num_epoch) | |||||
| # total time spent | |||||
| totaltime = datetime.datetime.now() - totaltime | |||||
| logger.info('Total time spent: {:.2f} hours\n'.format( | |||||
| totaltime.total_seconds() / 3600.0)) | |||||
| def run_stage(self, stage, num_epoch): | |||||
| """ | |||||
| Run training stages with correspond data | |||||
| Args: | |||||
| stage: id of stage | |||||
| num_epoch: the number of epoch to run in this stage | |||||
| """ | |||||
| if num_epoch <= 0: | |||||
| logger.warning(f'Invalid epoch number, stage {stage} exit!') | |||||
| return | |||||
| logger.info(f'Starting stage {stage}...') | |||||
| dataset, dataloader = self.create_dataloader( | |||||
| self.conf_files[stage * 2], self.conf_files[stage * 2 + 1]) | |||||
| it = iter(dataloader) | |||||
| for _ in range(num_epoch): | |||||
| self._current_epoch += 1 | |||||
| epochtime = datetime.datetime.now() | |||||
| logger.info('Start epoch %d...', self._current_epoch) | |||||
| loss_train_epoch = 0.0 | |||||
| validbatchs = 0 | |||||
| for bi in range(self._train_iters): | |||||
| # prepare data | |||||
| feat, label = next(it) | |||||
| label = torch.reshape(label, (-1, )) | |||||
| feat = to_device(feat, self.device) | |||||
| label = to_device(label, self.device) | |||||
| # apply model | |||||
| self.optimizer.zero_grad() | |||||
| predict = self.model(feat) | |||||
| # calculate loss | |||||
| loss = self.loss_fn( | |||||
| torch.reshape(predict, (-1, self._num_classes)), label) | |||||
| if not np.isnan(loss.item()): | |||||
| loss.backward() | |||||
| self.optimizer.step() | |||||
| loss_train_epoch += loss.item() | |||||
| validbatchs += 1 | |||||
| train_result = 'Epoch: {:04d}/{:04d}, batch: {:04d}/{:04d}, loss: {:.4f}'.format( | |||||
| self._current_epoch, self._max_epochs, bi + 1, | |||||
| self._train_iters, loss.item()) | |||||
| logger.info(train_result) | |||||
| self._dump_log(train_result) | |||||
| # average training loss in one epoch | |||||
| loss_train_epoch /= validbatchs | |||||
| loss_val_epoch = self.evaluate('') | |||||
| val_result = 'Evaluate epoch: {:04d}, loss_train: {:.4f}, loss_val: {:.4f}'.format( | |||||
| self._current_epoch, loss_train_epoch, loss_val_epoch) | |||||
| logger.info(val_result) | |||||
| self._dump_log(val_result) | |||||
| # check point | |||||
| ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format( | |||||
| self._current_epoch, loss_train_epoch, loss_val_epoch) | |||||
| torch.save(self.model, os.path.join(self.work_dir, ckpt_name)) | |||||
| # time spent per epoch | |||||
| epochtime = datetime.datetime.now() - epochtime | |||||
| logger.info('Epoch {:04d} time spent: {:.2f} hours'.format( | |||||
| self._current_epoch, | |||||
| epochtime.total_seconds() / 3600.0)) | |||||
| dataloader.stop() | |||||
| dataset.release() | |||||
| logger.info(f'Stage {stage} is finished.') | |||||
| def gen_val(self): | |||||
| """ | |||||
| generate validation set | |||||
| """ | |||||
| logger.info('Start generating validation set...') | |||||
| dataset, dataloader = self.create_dataloader(self.conf_files[2], | |||||
| self.conf_files[3]) | |||||
| it = iter(dataloader) | |||||
| self.data_val = [] | |||||
| for bi in range(self._val_iters): | |||||
| logger.info('Iterating validation data %d', bi) | |||||
| feat, label = next(it) | |||||
| label = torch.reshape(label, (-1, )) | |||||
| self.data_val.append([feat, label]) | |||||
| dataloader.stop() | |||||
| dataset.release() | |||||
| logger.info('Finish generating validation set!') | |||||
| def create_dataloader(self, base_path, finetune_path): | |||||
| dataset = KWSDataset(base_path, finetune_path, self._threads, | |||||
| self._single_rate, self._num_classes) | |||||
| dataloader = KWSDataLoader( | |||||
| dataset, batchsize=self._batch_size, numworkers=self._threads) | |||||
| dataloader.start() | |||||
| return dataset, dataloader | |||||
| def evaluate(self, checkpoint_path: str, *args, | |||||
| **kwargs) -> Dict[str, float]: | |||||
| logger.info('Start validation...') | |||||
| loss_val_epoch = 0.0 | |||||
| with torch.no_grad(): | |||||
| for feat, label in self.data_val: | |||||
| feat = to_device(feat, self.device) | |||||
| label = to_device(label, self.device) | |||||
| # apply model | |||||
| predict = self.model(feat) | |||||
| # calculate loss | |||||
| loss = self.loss_fn( | |||||
| torch.reshape(predict, (-1, self._num_classes)), label) | |||||
| loss_val_epoch += loss.item() | |||||
| logger.info('Finish validation.') | |||||
| return loss_val_epoch / self._val_iters | |||||
| def _dump_log(self, msg): | |||||
| if is_master(): | |||||
| with open(self.json_log_path, 'a+') as f: | |||||
| f.write(msg) | |||||
| f.write('\n') | |||||
| @@ -1,4 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import re | |||||
| import struct | import struct | ||||
| from typing import Union | from typing import Union | ||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||
| @@ -37,6 +38,23 @@ def audio_norm(x): | |||||
| return x | return x | ||||
| def update_conf(origin_config_file, new_config_file, conf_item: [str, str]): | |||||
| def repl(matched): | |||||
| key = matched.group(1) | |||||
| if key in conf_item: | |||||
| return conf_item[key] | |||||
| else: | |||||
| return None | |||||
| with open(origin_config_file) as f: | |||||
| lines = f.readlines() | |||||
| with open(new_config_file, 'w') as f: | |||||
| for line in lines: | |||||
| line = re.sub(r'\$\{(.*)\}', repl, line) | |||||
| f.write(line) | |||||
| def extract_pcm_from_wav(wav: bytes) -> bytes: | def extract_pcm_from_wav(wav: bytes) -> bytes: | ||||
| data = wav | data = wav | ||||
| if len(data) > 44: | if len(data) > 44: | ||||
| @@ -14,7 +14,11 @@ nltk | |||||
| numpy<=1.18 | numpy<=1.18 | ||||
| # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | ||||
| protobuf>3,<3.21.0 | protobuf>3,<3.21.0 | ||||
| py_sound_connect | |||||
| ptflops | |||||
| py_sound_connect>=0.1 | |||||
| pytorch_wavelets | |||||
| PyWavelets>=1.0.0 | |||||
| scikit-learn | |||||
| SoundFile>0.10 | SoundFile>0.10 | ||||
| sox | sox | ||||
| torchaudio | torchaudio | ||||
| @@ -0,0 +1,85 @@ | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.test_utils import test_level | |||||
| POS_FILE = 'data/test/audios/wake_word_with_label_xyxy.wav' | |||||
| NEG_FILE = 'data/test/audios/speech_with_noise.wav' | |||||
| NOISE_FILE = 'data/test/audios/speech_with_noise.wav' | |||||
| INTERF_FILE = 'data/test/audios/speech_with_noise.wav' | |||||
| REF_FILE = 'data/test/audios/farend_speech.wav' | |||||
| NOISE_2CH_FILE = 'data/test/audios/noise_2ch.wav' | |||||
| class TestKwsFarfieldTrainer(unittest.TestCase): | |||||
| REVISION = 'beta' | |||||
| def setUp(self): | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||||
| print(f'tmp dir: {self.tmp_dir}') | |||||
| if not os.path.exists(self.tmp_dir): | |||||
| os.makedirs(self.tmp_dir) | |||||
| self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' | |||||
| train_pos_list = self.create_list('pos.list', POS_FILE) | |||||
| train_neg_list = self.create_list('neg.list', NEG_FILE) | |||||
| train_noise1_list = self.create_list('noise.list', NOISE_FILE) | |||||
| train_noise2_list = self.create_list('noise_2ch.list', NOISE_2CH_FILE) | |||||
| train_interf_list = self.create_list('interf.list', INTERF_FILE) | |||||
| train_ref_list = self.create_list('ref.list', REF_FILE) | |||||
| base_dict = dict( | |||||
| train_pos_list=train_pos_list, | |||||
| train_neg_list=train_neg_list, | |||||
| train_noise1_list=train_noise1_list) | |||||
| fintune_dict = dict( | |||||
| train_pos_list=train_pos_list, | |||||
| train_neg_list=train_neg_list, | |||||
| train_noise1_list=train_noise1_list, | |||||
| train_noise2_type='1', | |||||
| train_noise1_ratio='0.2', | |||||
| train_noise2_list=train_noise2_list, | |||||
| train_interf_list=train_interf_list, | |||||
| train_ref_list=train_ref_list) | |||||
| self.custom_conf = dict( | |||||
| basetrain_easy=base_dict, | |||||
| basetrain_normal=base_dict, | |||||
| basetrain_hard=base_dict, | |||||
| finetune_easy=fintune_dict, | |||||
| finetune_normal=fintune_dict, | |||||
| finetune_hard=fintune_dict) | |||||
| def create_list(self, list_name, audio_file): | |||||
| pos_list_file = os.path.join(self.tmp_dir, list_name) | |||||
| with open(pos_list_file, 'w') as f: | |||||
| for i in range(10): | |||||
| f.write(f'{os.path.join(os.getcwd(), audio_file)}\n') | |||||
| train_pos_list = f'{pos_list_file}, 1.0' | |||||
| return train_pos_list | |||||
| def tearDown(self) -> None: | |||||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||||
| super().tearDown() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_normal(self): | |||||
| kwargs = dict( | |||||
| model=self.model_id, | |||||
| work_dir=self.tmp_dir, | |||||
| model_revision=self.REVISION, | |||||
| workers=2, | |||||
| max_epochs=2, | |||||
| train_iters_per_epoch=2, | |||||
| val_iters_per_epoch=1, | |||||
| custom_conf=self.custom_conf) | |||||
| trainer = build_trainer( | |||||
| Trainers.speech_dfsmn_kws_char_farfield, default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files, | |||||
| f'work_dir:{self.tmp_dir}') | |||||