Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10275823 (master)
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d | |||
| size 2327764 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150 | |||
| size 68524 | |||
| @@ -285,6 +285,7 @@ class Trainers(object): | |||
| # audio trainers | |||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | |||
| speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' | |||
| class Preprocessors(object): | |||
| @@ -1,15 +1,14 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Dict | |||
| import torch | |||
| from typing import Dict, Optional | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models import TorchModel | |||
| from modelscope.models.base import Tensor | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.audio.audio_utils import update_conf | |||
| from modelscope.utils.constant import Tasks | |||
| from .fsmn_sele_v2 import FSMNSeleNetV2 | |||
| @@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel): | |||
| MODEL_TXT = 'model.txt' | |||
| SC_CONFIG = 'sound_connect.conf' | |||
| SC_CONF_ITEM_KWS_MODEL = '${kws_model}' | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| def __init__(self, | |||
| model_dir: str, | |||
| training: Optional[bool] = False, | |||
| *args, | |||
| **kwargs): | |||
| """initialize the dfsmn model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||
| model_bin_file = os.path.join(model_dir, | |||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||
| self._model = None | |||
| if os.path.exists(model_bin_file): | |||
| kwargs.pop('device') | |||
| self._model = FSMNSeleNetV2(*args, **kwargs) | |||
| checkpoint = torch.load(model_bin_file) | |||
| self._model.load_state_dict(checkpoint, strict=False) | |||
| self._sc = None | |||
| if os.path.exists(model_txt_file): | |||
| with open(sc_config_file) as f: | |||
| lines = f.readlines() | |||
| with open(sc_config_file, 'w') as f: | |||
| for line in lines: | |||
| if self.SC_CONF_ITEM_KWS_MODEL in line: | |||
| line = line.replace(self.SC_CONF_ITEM_KWS_MODEL, | |||
| model_txt_file) | |||
| f.write(line) | |||
| import py_sound_connect | |||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||
| self.size_in = self._sc.bytesPerBlockIn() | |||
| self.size_out = self._sc.bytesPerBlockOut() | |||
| if self._model is None and self._sc is None: | |||
| raise Exception( | |||
| f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.' | |||
| ) | |||
| if training: | |||
| self.model = FSMNSeleNetV2(*args, **kwargs) | |||
| else: | |||
| sc_config_file = os.path.join(model_dir, self.SC_CONFIG) | |||
| model_txt_file = os.path.join(model_dir, self.MODEL_TXT) | |||
| self._sc = None | |||
| if os.path.exists(model_txt_file): | |||
| conf_dict = dict(mode=56542, kws_model=model_txt_file) | |||
| update_conf(sc_config_file, sc_config_file, conf_dict) | |||
| import py_sound_connect | |||
| self._sc = py_sound_connect.SoundConnect(sc_config_file) | |||
| self.size_in = self._sc.bytesPerBlockIn() | |||
| self.size_out = self._sc.bytesPerBlockOut() | |||
| else: | |||
| raise Exception( | |||
| f'Invalid model directory! Failed to load model file: {model_txt_file}.' | |||
| ) | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| ... | |||
| return self.model.forward(input) | |||
| def forward_decode(self, data: bytes): | |||
| result = {'pcm': self._sc.process(data, self.size_out)} | |||
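A minimal usage sketch of the two construction paths above; the import path, file names and constructor kwargs are assumptions for illustration, not taken from this diff:

# Hypothetical usage sketch; the module path below is assumed.
from modelscope.models.audio.kws.farfield.model import FSMNSeleNetV2Decorator

# Training path: builds the raw FSMNSeleNetV2 network from the constructor kwargs.
train_model = FSMNSeleNetV2Decorator('/path/to/model_dir', training=True)

# Inference path: rewrites sound_connect.conf via update_conf and loads py_sound_connect.
infer_model = FSMNSeleNetV2Decorator('/path/to/model_dir')
with open('audio.pcm', 'rb') as f:
    block = f.read(infer_model.size_in)
result = infer_model.forward_decode(block)  # {'pcm': ...}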
| @@ -0,0 +1,21 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .kws_farfield_dataset import KWSDataset, KWSDataLoader | |||
| else: | |||
| _import_structure = { | |||
| 'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
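With the lazy-import registration above, the dataset classes are only materialized on first access; a short sketch of the resulting import (the package path matches the one used by the trainer later in this diff):

# Resolved lazily by LazyImportModule: kws_farfield_dataset (and its torch
# dependency) is imported only when these names are first accessed.
from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset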
| @@ -0,0 +1,280 @@ | |||
| """ | |||
| Used to prepare simulated data. | |||
| """ | |||
| import math | |||
| import os.path | |||
| import queue | |||
| import threading | |||
| import time | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| BLOCK_DEC = 2 | |||
| BLOCK_CAT = 3 | |||
| FBANK_SIZE = 40 | |||
| LABEL_SIZE = 1 | |||
| LABEL_GAIN = 100.0 | |||
| class KWSDataset: | |||
| """ | |||
| dataset for keyword spotting and vad | |||
| conf_basetrain: basetrain configure file path | |||
| conf_finetune: finetune configure file path, null allowed | |||
| numworkers: no. of workers | |||
| basetrainratio: basetrain workers ratio | |||
| numclasses: no. of nn output classes, 2 classes to generate vad label | |||
| blockdec: block decimation | |||
| blockcat: block concatenation | |||
| """ | |||
| def __init__(self, | |||
| conf_basetrain, | |||
| conf_finetune, | |||
| numworkers, | |||
| basetrainratio, | |||
| numclasses, | |||
| blockdec=BLOCK_DEC, | |||
| blockcat=BLOCK_CAT): | |||
| super().__init__() | |||
| self.numclasses = numclasses | |||
| self.blockdec = blockdec | |||
| self.blockcat = blockcat | |||
| self.sims_base = [] | |||
| self.sims_senior = [] | |||
| self.setup_sims(conf_basetrain, conf_finetune, numworkers, | |||
| basetrainratio) | |||
| def release(self): | |||
| for sim in self.sims_base: | |||
| del sim | |||
| for sim in self.sims_senior: | |||
| del sim | |||
| del self.base_conf | |||
| del self.senior_conf | |||
| logger.info('KWSDataset: Released.') | |||
| def setup_sims(self, conf_basetrain, conf_finetune, numworkers, | |||
| basetrainratio): | |||
| if not os.path.exists(conf_basetrain): | |||
| raise ValueError(f'{conf_basetrain} does not exist!') | |||
| if not os.path.exists(conf_finetune): | |||
| raise ValueError(f'{conf_finetune} does not exist!') | |||
| import py_sound_connect | |||
| logger.info('KWSDataset init SoundConnect...') | |||
| num_base = math.ceil(numworkers * basetrainratio) | |||
| num_senior = numworkers - num_base | |||
| # keep references in fields so Python does not release the conf objects | |||
| self.base_conf = py_sound_connect.ConfigFile(conf_basetrain) | |||
| self.senior_conf = py_sound_connect.ConfigFile(conf_finetune) | |||
| for i in range(num_base): | |||
| fs = py_sound_connect.FeatSimuKWS(self.base_conf.params) | |||
| self.sims_base.append(fs) | |||
| for i in range(num_senior): | |||
| self.sims_senior.append( | |||
| py_sound_connect.FeatSimuKWS(self.senior_conf.params)) | |||
| logger.info('KWSDataset init SoundConnect finished.') | |||
| def getBatch(self, id): | |||
| """ | |||
| Generate a data batch | |||
| Args: | |||
| id: worker id | |||
| Return: time x channel x feature, label | |||
| """ | |||
| fs = self.get_sim(id) | |||
| fs.processBatch() | |||
| # get multi-channel feature vector size | |||
| featsize = fs.featSize() | |||
| # get label vector size | |||
| labelsize = fs.labelSize() | |||
| # get minibatch size (time dimension) | |||
| # batchsize = fs.featBatchSize() | |||
| # no. of fe output channels | |||
| numchs = featsize // FBANK_SIZE | |||
| # get raw data | |||
| fs_feat = fs.feat() | |||
| data = np.frombuffer(fs_feat, dtype='float32') | |||
| data = data.reshape((-1, featsize + labelsize)) | |||
| # convert float label to int | |||
| label = data[:, FBANK_SIZE * numchs:] | |||
| if self.numclasses == 2: | |||
| # generate vad label | |||
| label[label > 0.0] = 1.0 | |||
| else: | |||
| # generate kws label | |||
| label = np.round(label * LABEL_GAIN) | |||
| label[label > self.numclasses - 1] = 0.0 | |||
| # decimated size | |||
| size1 = int(np.ceil( | |||
| label.shape[0] / self.blockdec)) - self.blockcat + 1 | |||
| # label decimation | |||
| label1 = np.zeros((size1, LABEL_SIZE), dtype='float32') | |||
| for tau in range(size1): | |||
| label1[tau, :] = label[(tau + self.blockcat // 2) | |||
| * self.blockdec, :] | |||
| # feature decimation and concatenation | |||
| # time x channel x feature | |||
| featall = np.zeros((size1, numchs, FBANK_SIZE * self.blockcat), | |||
| dtype='float32') | |||
| for n in range(numchs): | |||
| feat = data[:, FBANK_SIZE * n:FBANK_SIZE * (n + 1)] | |||
| for tau in range(size1): | |||
| for i in range(self.blockcat): | |||
| featall[tau, n, FBANK_SIZE * i:FBANK_SIZE * (i + 1)] = \ | |||
| feat[(tau + i) * self.blockdec, :] | |||
| return torch.from_numpy(featall), torch.from_numpy(label1).long() | |||
| def get_sim(self, id): | |||
| num_base = len(self.sims_base) | |||
| if id < num_base: | |||
| fs = self.sims_base[id] | |||
| else: | |||
| fs = self.sims_senior[id - num_base] | |||
| return fs | |||
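As a worked example of the decimation arithmetic in `getBatch` (the frame count is illustrative): with 100 frames, `blockdec=2` and `blockcat=3`, `size1 = ceil(100 / 2) - 3 + 1 = 48`, so `featall` has shape `(48, numchs, 120)` given `FBANK_SIZE = 40`, and `label1` has shape `(48, 1)`.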
| class Worker(threading.Thread): | |||
| """ | |||
| id: worker id | |||
| dataset: the dataset | |||
| pool: queue as the global data buffer | |||
| """ | |||
| def __init__(self, id, dataset, pool): | |||
| threading.Thread.__init__(self) | |||
| self.id = id | |||
| self.dataset = dataset | |||
| self.pool = pool | |||
| self.isrun = True | |||
| self.nn = 0 | |||
| def run(self): | |||
| while self.isrun: | |||
| self.nn += 1 | |||
| logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:1') | |||
| # get simulated minibatch | |||
| if self.isrun: | |||
| data = self.dataset.getBatch(self.id) | |||
| logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:2') | |||
| # put data into buffer | |||
| if self.isrun: | |||
| self.pool.put(data) | |||
| logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:3') | |||
| logger.info('KWSDataLoader: Worker {:02d} stopped.'.format(self.id)) | |||
| def stopWorker(self): | |||
| """ | |||
| stop the worker thread | |||
| """ | |||
| self.isrun = False | |||
| class KWSDataLoader: | |||
| """ | |||
| dataset: the dataset reference | |||
| batchsize: data batch size | |||
| numworkers: no. of workers | |||
| prefetch: prefetch factor | |||
| """ | |||
| def __init__(self, dataset, batchsize, numworkers, prefetch=2): | |||
| self.dataset = dataset | |||
| self.batchsize = batchsize | |||
| self.datamap = {} | |||
| self.isrun = True | |||
| # data queue | |||
| self.pool = queue.Queue(batchsize * prefetch) | |||
| # initialize workers | |||
| self.workerlist = [] | |||
| for id in range(numworkers): | |||
| w = Worker(id, dataset, self.pool) | |||
| self.workerlist.append(w) | |||
| def __iter__(self): | |||
| return self | |||
| def __next__(self): | |||
| while self.isrun: | |||
| # get data from common data pool | |||
| data = self.pool.get() | |||
| self.pool.task_done() | |||
| # group minibatches with the same shape | |||
| key = str(data[0].shape) | |||
| batchl = self.datamap.get(key) | |||
| if batchl is None: | |||
| batchl = [] | |||
| self.datamap.update({key: batchl}) | |||
| batchl.append(data) | |||
| # a full data batch collected | |||
| if len(batchl) >= self.batchsize: | |||
| featbatch = [] | |||
| labelbatch = [] | |||
| for feat, label in batchl: | |||
| featbatch.append(feat) | |||
| labelbatch.append(label) | |||
| batchl.clear() | |||
| feattensor = torch.stack(featbatch, dim=0) | |||
| labeltensor = torch.stack(labelbatch, dim=0) | |||
| if feattensor.shape[-2] == 1: | |||
| logger.debug('KWSDataLoader: Basetrain batch.') | |||
| else: | |||
| logger.debug('KWSDataLoader: Finetune batch.') | |||
| return feattensor, labeltensor | |||
| return None, None | |||
| def start(self): | |||
| """ | |||
| start multi-thread data loader | |||
| """ | |||
| for w in self.workerlist: | |||
| w.start() | |||
| def stop(self): | |||
| """ | |||
| stop data loader | |||
| """ | |||
| logger.info('KWSDataLoader: Stopping...') | |||
| self.isrun = False | |||
| for w in self.workerlist: | |||
| w.stopWorker() | |||
| while not self.pool.empty(): | |||
| self.pool.get(block=True, timeout=0.001) | |||
| # wait for workers to terminate; keep draining so blocked workers can exit | |||
| for w in self.workerlist: | |||
| while not self.pool.empty(): | |||
| self.pool.get(block=True, timeout=0.001) | |||
| w.join() | |||
| logger.info('KWSDataLoader: All workers stopped.') | |||
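A minimal sketch of wiring the dataset and loader together (mirroring `create_dataloader` in the trainer below; file paths and sizes are placeholders):

dataset = KWSDataset('basetrain_easy.conf', 'finetune_easy.conf',
                     numworkers=4, basetrainratio=0.5, numclasses=8)
loader = KWSDataLoader(dataset, batchsize=32, numworkers=4)
loader.start()
feat, label = next(iter(loader))
# feat: (batch, time, channels, FBANK_SIZE * blockcat), label: (batch, time, 1)
loader.stop()
dataset.release()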
| @@ -0,0 +1,279 @@ | |||
| import datetime | |||
| import math | |||
| import os | |||
| from typing import Callable, Dict, Optional | |||
| import numpy as np | |||
| import torch | |||
| from torch import nn as nn | |||
| from torch import optim as optim | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models import Model, TorchModel | |||
| from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset | |||
| from modelscope.trainers.base import BaseTrainer | |||
| from modelscope.trainers.builder import TRAINERS | |||
| from modelscope.utils.audio.audio_utils import update_conf | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
| from modelscope.utils.data_utils import to_device | |||
| from modelscope.utils.device import create_device | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, | |||
| init_dist, is_master) | |||
| logger = get_logger() | |||
| BASETRAIN_CONF_EASY = 'basetrain_easy' | |||
| BASETRAIN_CONF_NORMAL = 'basetrain_normal' | |||
| BASETRAIN_CONF_HARD = 'basetrain_hard' | |||
| FINETUNE_CONF_EASY = 'finetune_easy' | |||
| FINETUNE_CONF_NORMAL = 'finetune_normal' | |||
| FINETUNE_CONF_HARD = 'finetune_hard' | |||
| EASY_RATIO = 0.1 | |||
| NORMAL_RATIO = 0.6 | |||
| HARD_RATIO = 0.3 | |||
| BASETRAIN_RATIO = 0.5 | |||
| @TRAINERS.register_module(module_name=Trainers.speech_dfsmn_kws_char_farfield) | |||
| class KWSFarfieldTrainer(BaseTrainer): | |||
| DEFAULT_WORK_DIR = './work_dir' | |||
| conf_keys = (BASETRAIN_CONF_EASY, FINETUNE_CONF_EASY, | |||
| BASETRAIN_CONF_NORMAL, FINETUNE_CONF_NORMAL, | |||
| BASETRAIN_CONF_HARD, FINETUNE_CONF_HARD) | |||
| def __init__(self, | |||
| model: str, | |||
| work_dir: str, | |||
| cfg_file: Optional[str] = None, | |||
| arg_parse_fn: Optional[Callable] = None, | |||
| model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
| custom_conf: Optional[dict] = None, | |||
| **kwargs): | |||
| if isinstance(model, str): | |||
| if os.path.exists(model): | |||
| self.model_dir = model if os.path.isdir( | |||
| model) else os.path.dirname(model) | |||
| else: | |||
| self.model_dir = snapshot_download( | |||
| model, revision=model_revision) | |||
| if cfg_file is None: | |||
| cfg_file = os.path.join(self.model_dir, | |||
| ModelFile.CONFIGURATION) | |||
| else: | |||
| assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' | |||
| self.model_dir = os.path.dirname(cfg_file) | |||
| super().__init__(cfg_file, arg_parse_fn) | |||
| self.model = self.build_model() | |||
| self.work_dir = work_dir | |||
| # number of model output dimensions | |||
| # the config should be updated outside the trainer if more wake words are needed | |||
| self._num_classes = self.cfg.model.num_syn | |||
| if kwargs.get('launcher', None) is not None: | |||
| init_dist(kwargs['launcher']) | |||
| _, world_size = get_dist_info() | |||
| self._dist = world_size > 1 | |||
| device_name = kwargs.get('device', 'gpu') | |||
| if self._dist: | |||
| local_rank = get_local_rank() | |||
| device_name = f'cuda:{local_rank}' | |||
| self.device = create_device(device_name) | |||
| # model placement | |||
| if self.device.type == 'cuda': | |||
| self.model.to(self.device) | |||
| if 'max_epochs' not in kwargs: | |||
| assert hasattr( | |||
| self.cfg.train, 'max_epochs' | |||
| ), 'max_epochs is missing from the configuration file' | |||
| self._max_epochs = self.cfg.train.max_epochs | |||
| else: | |||
| self._max_epochs = kwargs['max_epochs'] | |||
| self._train_iters = kwargs.get('train_iters_per_epoch', None) | |||
| self._val_iters = kwargs.get('val_iters_per_epoch', None) | |||
| if self._train_iters is None: | |||
| self._train_iters = self.cfg.train.train_iters_per_epoch | |||
| if self._val_iters is None: | |||
| self._val_iters = self.cfg.evaluation.val_iters_per_epoch | |||
| dataloader_config = self.cfg.train.dataloader | |||
| self._threads = kwargs.get('workers', None) | |||
| if self._threads is None: | |||
| self._threads = dataloader_config.workers_per_gpu | |||
| self._single_rate = BASETRAIN_RATIO | |||
| if 'single_rate' in kwargs: | |||
| self._single_rate = kwargs['single_rate'] | |||
| self._batch_size = dataloader_config.batch_size_per_gpu | |||
| if 'model_bin' in kwargs: | |||
| model_bin_file = os.path.join(self.model_dir, kwargs['model_bin']) | |||
| checkpoint = torch.load(model_bin_file) | |||
| self.model.load_state_dict(checkpoint) | |||
| # build corresponding optimizer and loss function | |||
| lr = self.cfg.train.optimizer.lr | |||
| self.optimizer = optim.Adam(self.model.parameters(), lr) | |||
| self.loss_fn = nn.CrossEntropyLoss() | |||
| self.data_val = None | |||
| self.json_log_path = os.path.join(self.work_dir, | |||
| '{}.log.json'.format(self.timestamp)) | |||
| self.conf_files = [] | |||
| for conf_key in self.conf_keys: | |||
| template_file = os.path.join(self.model_dir, conf_key) | |||
| conf_file = os.path.join(self.model_dir, f'{conf_key}.conf') | |||
| update_conf(template_file, conf_file, custom_conf[conf_key]) | |||
| self.conf_files.append(conf_file) | |||
| self._current_epoch = 0 | |||
| self.stages = (math.floor(self._max_epochs * EASY_RATIO), | |||
| math.floor(self._max_epochs * NORMAL_RATIO), | |||
| math.floor(self._max_epochs * HARD_RATIO)) | |||
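For example, with `max_epochs=10` the split above yields `stages = (1, 6, 3)`: one easy epoch, six normal epochs and three hard epochs.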
| def build_model(self) -> nn.Module: | |||
| """ Instantiate a pytorch model and return. | |||
| By default, we will create a model using config from configuration file. You can | |||
| override this method in a subclass. | |||
| """ | |||
| model = Model.from_pretrained( | |||
| self.model_dir, cfg_dict=self.cfg, training=True) | |||
| if isinstance(model, TorchModel) and hasattr(model, 'model'): | |||
| return model.model | |||
| elif isinstance(model, nn.Module): | |||
| return model | |||
| def train(self, *args, **kwargs): | |||
| if not self.data_val: | |||
| self.gen_val() | |||
| logger.info('Start training...') | |||
| totaltime = datetime.datetime.now() | |||
| for stage, num_epoch in enumerate(self.stages): | |||
| self.run_stage(stage, num_epoch) | |||
| # total time spent | |||
| totaltime = datetime.datetime.now() - totaltime | |||
| logger.info('Total time spent: {:.2f} hours\n'.format( | |||
| totaltime.total_seconds() / 3600.0)) | |||
| def run_stage(self, stage, num_epoch): | |||
| """ | |||
| Run training stages with correspond data | |||
| Args: | |||
| stage: id of stage | |||
| num_epoch: the number of epoch to run in this stage | |||
| """ | |||
| if num_epoch <= 0: | |||
| logger.warning(f'Invalid epoch number, stage {stage} skipped!') | |||
| return | |||
| logger.info(f'Starting stage {stage}...') | |||
| dataset, dataloader = self.create_dataloader( | |||
| self.conf_files[stage * 2], self.conf_files[stage * 2 + 1]) | |||
| it = iter(dataloader) | |||
| for _ in range(num_epoch): | |||
| self._current_epoch += 1 | |||
| epochtime = datetime.datetime.now() | |||
| logger.info('Start epoch %d...', self._current_epoch) | |||
| loss_train_epoch = 0.0 | |||
| validbatchs = 0 | |||
| for bi in range(self._train_iters): | |||
| # prepare data | |||
| feat, label = next(it) | |||
| label = torch.reshape(label, (-1, )) | |||
| feat = to_device(feat, self.device) | |||
| label = to_device(label, self.device) | |||
| # apply model | |||
| self.optimizer.zero_grad() | |||
| predict = self.model(feat) | |||
| # calculate loss | |||
| loss = self.loss_fn( | |||
| torch.reshape(predict, (-1, self._num_classes)), label) | |||
| if not np.isnan(loss.item()): | |||
| loss.backward() | |||
| self.optimizer.step() | |||
| loss_train_epoch += loss.item() | |||
| validbatchs += 1 | |||
| train_result = 'Epoch: {:04d}/{:04d}, batch: {:04d}/{:04d}, loss: {:.4f}'.format( | |||
| self._current_epoch, self._max_epochs, bi + 1, | |||
| self._train_iters, loss.item()) | |||
| logger.info(train_result) | |||
| self._dump_log(train_result) | |||
| # average training loss in one epoch | |||
| loss_train_epoch /= validbatchs | |||
| loss_val_epoch = self.evaluate('') | |||
| val_result = 'Evaluate epoch: {:04d}, loss_train: {:.4f}, loss_val: {:.4f}'.format( | |||
| self._current_epoch, loss_train_epoch, loss_val_epoch) | |||
| logger.info(val_result) | |||
| self._dump_log(val_result) | |||
| # checkpoint | |||
| ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format( | |||
| self._current_epoch, loss_train_epoch, loss_val_epoch) | |||
| torch.save(self.model, os.path.join(self.work_dir, ckpt_name)) | |||
| # time spent per epoch | |||
| epochtime = datetime.datetime.now() - epochtime | |||
| logger.info('Epoch {:04d} time spent: {:.2f} hours'.format( | |||
| self._current_epoch, | |||
| epochtime.total_seconds() / 3600.0)) | |||
| dataloader.stop() | |||
| dataset.release() | |||
| logger.info(f'Stage {stage} is finished.') | |||
| def gen_val(self): | |||
| """ | |||
| generate validation set | |||
| """ | |||
| logger.info('Start generating validation set...') | |||
| dataset, dataloader = self.create_dataloader(self.conf_files[2], | |||
| self.conf_files[3]) | |||
| it = iter(dataloader) | |||
| self.data_val = [] | |||
| for bi in range(self._val_iters): | |||
| logger.info('Iterating validation data %d', bi) | |||
| feat, label = next(it) | |||
| label = torch.reshape(label, (-1, )) | |||
| self.data_val.append([feat, label]) | |||
| dataloader.stop() | |||
| dataset.release() | |||
| logger.info('Finish generating validation set!') | |||
| def create_dataloader(self, base_path, finetune_path): | |||
| dataset = KWSDataset(base_path, finetune_path, self._threads, | |||
| self._single_rate, self._num_classes) | |||
| dataloader = KWSDataLoader( | |||
| dataset, batchsize=self._batch_size, numworkers=self._threads) | |||
| dataloader.start() | |||
| return dataset, dataloader | |||
| def evaluate(self, checkpoint_path: str, *args, | |||
| **kwargs) -> Dict[str, float]: | |||
| logger.info('Start validation...') | |||
| loss_val_epoch = 0.0 | |||
| with torch.no_grad(): | |||
| for feat, label in self.data_val: | |||
| feat = to_device(feat, self.device) | |||
| label = to_device(label, self.device) | |||
| # apply model | |||
| predict = self.model(feat) | |||
| # calculate loss | |||
| loss = self.loss_fn( | |||
| torch.reshape(predict, (-1, self._num_classes)), label) | |||
| loss_val_epoch += loss.item() | |||
| logger.info('Finish validation.') | |||
| return loss_val_epoch / self._val_iters | |||
| def _dump_log(self, msg): | |||
| if is_master(): | |||
| with open(self.json_log_path, 'a+') as f: | |||
| f.write(msg) | |||
| f.write('\n') | |||
| @@ -1,4 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import re | |||
| import struct | |||
| from typing import Dict, Union | |||
| from urllib.parse import urlparse | |||
| @@ -37,6 +38,23 @@ def audio_norm(x): | |||
| return x | |||
| def update_conf(origin_config_file, new_config_file, conf_item: Dict[str, str]): | |||
| def repl(matched): | |||
| key = matched.group(1) | |||
| if key in conf_item: | |||
| # cast to str: re.sub requires the replacement to be a string | |||
| return str(conf_item[key]) | |||
| else: | |||
| # keep the original text when the placeholder key is unknown | |||
| return matched.group(0) | |||
| with open(origin_config_file) as f: | |||
| lines = f.readlines() | |||
| with open(new_config_file, 'w') as f: | |||
| for line in lines: | |||
| line = re.sub(r'\$\{(.*)\}', repl, line) | |||
| f.write(line) | |||
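A small sketch of how `update_conf` fills the `${...}` placeholders (file names and keys are illustrative):

# template.conf contains a line such as:  kws_model = ${kws_model}
update_conf('template.conf', 'sound_connect.conf',
            {'kws_model': '/path/to/model.txt'})
# the written sound_connect.conf then contains:  kws_model = /path/to/model.txt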
| def extract_pcm_from_wav(wav: bytes) -> bytes: | |||
| data = wav | |||
| if len(data) > 44: | |||
| @@ -14,7 +14,11 @@ nltk | |||
| numpy<=1.18 | |||
| # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | |||
| protobuf>3,<3.21.0 | |||
| py_sound_connect | |||
| ptflops | |||
| py_sound_connect>=0.1 | |||
| pytorch_wavelets | |||
| PyWavelets>=1.0.0 | |||
| scikit-learn | |||
| SoundFile>0.10 | |||
| sox | |||
| torchaudio | |||
| @@ -0,0 +1,85 @@ | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.test_utils import test_level | |||
| POS_FILE = 'data/test/audios/wake_word_with_label_xyxy.wav' | |||
| NEG_FILE = 'data/test/audios/speech_with_noise.wav' | |||
| NOISE_FILE = 'data/test/audios/speech_with_noise.wav' | |||
| INTERF_FILE = 'data/test/audios/speech_with_noise.wav' | |||
| REF_FILE = 'data/test/audios/farend_speech.wav' | |||
| NOISE_2CH_FILE = 'data/test/audios/noise_2ch.wav' | |||
| class TestKwsFarfieldTrainer(unittest.TestCase): | |||
| REVISION = 'beta' | |||
| def setUp(self): | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| print(f'tmp dir: {self.tmp_dir}') | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya' | |||
| train_pos_list = self.create_list('pos.list', POS_FILE) | |||
| train_neg_list = self.create_list('neg.list', NEG_FILE) | |||
| train_noise1_list = self.create_list('noise.list', NOISE_FILE) | |||
| train_noise2_list = self.create_list('noise_2ch.list', NOISE_2CH_FILE) | |||
| train_interf_list = self.create_list('interf.list', INTERF_FILE) | |||
| train_ref_list = self.create_list('ref.list', REF_FILE) | |||
| base_dict = dict( | |||
| train_pos_list=train_pos_list, | |||
| train_neg_list=train_neg_list, | |||
| train_noise1_list=train_noise1_list) | |||
| finetune_dict = dict( | |||
| train_pos_list=train_pos_list, | |||
| train_neg_list=train_neg_list, | |||
| train_noise1_list=train_noise1_list, | |||
| train_noise2_type='1', | |||
| train_noise1_ratio='0.2', | |||
| train_noise2_list=train_noise2_list, | |||
| train_interf_list=train_interf_list, | |||
| train_ref_list=train_ref_list) | |||
| self.custom_conf = dict( | |||
| basetrain_easy=base_dict, | |||
| basetrain_normal=base_dict, | |||
| basetrain_hard=base_dict, | |||
| finetune_easy=finetune_dict, | |||
| finetune_normal=finetune_dict, | |||
| finetune_hard=finetune_dict) | |||
| def create_list(self, list_name, audio_file): | |||
| pos_list_file = os.path.join(self.tmp_dir, list_name) | |||
| with open(pos_list_file, 'w') as f: | |||
| for i in range(10): | |||
| f.write(f'{os.path.join(os.getcwd(), audio_file)}\n') | |||
| train_pos_list = f'{pos_list_file}, 1.0' | |||
| return train_pos_list | |||
| def tearDown(self) -> None: | |||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_normal(self): | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| work_dir=self.tmp_dir, | |||
| model_revision=self.REVISION, | |||
| workers=2, | |||
| max_epochs=2, | |||
| train_iters_per_epoch=2, | |||
| val_iters_per_epoch=1, | |||
| custom_conf=self.custom_conf) | |||
| trainer = build_trainer( | |||
| Trainers.speech_dfsmn_kws_char_farfield, default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files, | |||
| f'work_dir:{self.tmp_dir}') | |||