Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10743508
fix: load model directly from .pth (target branch: master)
@@ -54,7 +54,8 @@ class FSMNSeleNetV2Decorator(TorchModel):
         )

     def __del__(self):
-        self.tmp_dir.cleanup()
+        if hasattr(self, 'tmp_dir'):
+            self.tmp_dir.cleanup()

     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         return self.model.forward(input)
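Review note: the `hasattr` guard matters because CPython still invokes `__del__` on an object whose `__init__` raised before `self.tmp_dir` was assigned; without the guard, teardown itself raises `AttributeError`. A minimal sketch of the failure mode (the class name is illustrative, not from this PR):

```python
import tempfile

class Decorated:
    def __init__(self, fail=False):
        if fail:
            # tmp_dir is never assigned on this path
            raise RuntimeError('simulated __init__ failure')
        self.tmp_dir = tempfile.TemporaryDirectory()

    def __del__(self):
        # Without this guard, the partially built instance raises
        # AttributeError during garbage collection.
        if hasattr(self, 'tmp_dir'):
            self.tmp_dir.cleanup()

try:
    Decorated(fail=True)
except RuntimeError:
    pass  # __del__ still runs on the half-constructed object
```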
@@ -188,11 +188,13 @@ class Worker(threading.Thread):


 class KWSDataLoader:
-    """
-    dataset: the dataset reference
-    batchsize: data batch size
-    numworkers: no. of workers
-    prefetch: prefetch factor
+    """Load and organize audio data with multiple threads.
+
+    Args:
+        dataset: the dataset reference
+        batchsize: data batch size
+        numworkers: no. of workers
+        prefetch: prefetch factor
     """

     def __init__(self, dataset, batchsize, numworkers, prefetch=2):
@@ -202,7 +204,7 @@ class KWSDataLoader:
         self.isrun = True

         # data queue
-        self.pool = queue.Queue(batchsize * prefetch)
+        self.pool = queue.Queue(numworkers * prefetch)

         # initialize workers
         self.workerlist = []
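Review note: bounding the queue by `numworkers * prefetch` makes the buffer proportional to how many batches the workers can usefully stage ahead; the old bound `batchsize * prefetch` conflated samples per batch with buffered batches. A self-contained sketch of the pattern (all names and sizes are illustrative):

```python
import queue
import threading

NUM_WORKERS, PREFETCH, BATCHES_PER_WORKER = 4, 2, 10
# Each worker can stage at most PREFETCH batches ahead of the consumer.
pool = queue.Queue(NUM_WORKERS * PREFETCH)

def worker(wid):
    for i in range(BATCHES_PER_WORKER):
        pool.put((wid, i))  # blocks while the pool is full, throttling the producer

threads = [threading.Thread(target=worker, args=(w,)) for w in range(NUM_WORKERS)]
for t in threads:
    t.start()
for _ in range(NUM_WORKERS * BATCHES_PER_WORKER):
    pool.get()  # draining frees space, letting blocked workers resume
for t in threads:
    t.join()
```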
@@ -270,11 +272,11 @@ class KWSDataLoader:
             w.stopWorker()

         while not self.pool.empty():
-            self.pool.get(block=True, timeout=0.001)
+            self.pool.get(block=True, timeout=0.01)

         # wait workers terminated
         for w in self.workerlist:
             while not self.pool.empty():
-                self.pool.get(block=True, timeout=0.001)
+                self.pool.get(block=True, timeout=0.01)
             w.join()

         logger.info('KWSDataLoader: All worker stopped.')
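Review note: the longer 0.01 s timeout gives workers blocked in `pool.put()` more chances to wake and exit between drain passes. Since `empty()` followed by `get()` is not atomic, `get` can still raise `queue.Empty` under contention; a defensive drain helper would look like this (a sketch, not code from this PR):

```python
import queue

def drain(pool: queue.Queue, timeout: float = 0.01) -> None:
    """Discard buffered items until the pool stays empty."""
    while not pool.empty():
        try:
            pool.get(block=True, timeout=timeout)
        except queue.Empty:
            break  # a racing consumer took the last item; nothing left to drain
```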
@@ -117,8 +117,7 @@ class KWSFarfieldTrainer(BaseTrainer):
         self._batch_size = dataloader_config.batch_size_per_gpu
         if 'model_bin' in kwargs:
             model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
-            checkpoint = torch.load(model_bin_file)
-            self.model.load_state_dict(checkpoint)
+            self.model = torch.load(model_bin_file)

         # build corresponding optimizer and loss function
         lr = self.cfg.train.optimizer.lr
         self.optimizer = optim.Adam(self.model.parameters(), lr)
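Review note: this is the change the PR title describes. The `.pth` file is written with `torch.save(self.model, ...)` (see the checkpoint hunk below), i.e. it pickles the whole module, so treating its contents as a `state_dict` would fail; `torch.load` returns the module directly. A hedged sketch of the two checkpoint styles (the `nn.Linear` stands in for the real model):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the FSMN model

# Style used by this trainer: pickle the whole module.
torch.save(model, 'model.pth')
restored = torch.load('model.pth')  # PyTorch >= 2.6 needs weights_only=False

# Alternative style: save only the weights; requires building the module first.
torch.save(model.state_dict(), 'weights.pth')
fresh = nn.Linear(4, 2)
fresh.load_state_dict(torch.load('weights.pth'))
```

Whole-module pickling is convenient but ties the checkpoint to an importable class definition; the state-dict form is the more portable of the two.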
@@ -219,7 +218,9 @@ class KWSFarfieldTrainer(BaseTrainer):
             # check point
             ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
                 self._current_epoch, loss_train_epoch, loss_val_epoch)
-            torch.save(self.model, os.path.join(self.work_dir, ckpt_name))
+            save_path = os.path.join(self.work_dir, ckpt_name)
+            logger.info(f'Save model to {save_path}')
+            torch.save(self.model, save_path)

             # time spent per epoch
             epochtime = datetime.datetime.now() - epochtime
             logger.info('Epoch {:04d} time spent: {:.2f} hours'.format(
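Review note: factoring the destination into `save_path` lets it be logged before the write. For reference, the checkpoint name format expands like this:

```python
ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
    3, 0.1234, 0.2345)
print(ckpt_name)  # checkpoint_0003_loss_train_0.1234_loss_val_0.2345.pth
```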
@@ -43,7 +43,10 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]):
     def repl(matched):
         key = matched.group(1)
         if key in conf_item:
-            return conf_item[key]
+            value = conf_item[key]
+            if not isinstance(value, str):
+                value = str(value)
+            return value
         else:
             return None
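Review note: when `re.sub` is given a callable replacement, the callable must return a `str`, so non-string config values (ints, floats) need the `str()` conversion this hunk adds. The untouched `else: return None` branch would raise the same `TypeError` if a placeholder key were ever missing; returning `matched.group(0)` is the usual safe fallback. A minimal sketch of the pattern (placeholder syntax and keys are illustrative):

```python
import re

conf_item = {'sample_rate': 16000, 'model': 'fsmn'}  # illustrative values

def repl(matched):
    key = matched.group(1)
    if key in conf_item:
        return str(conf_item[key])  # re.sub requires the callable to return str
    return matched.group(0)  # leave unknown placeholders untouched

template = 'rate=${sample_rate} arch=${model} extra=${unknown}'
print(re.sub(r'\$\{(\w+)\}', repl, template))
# -> rate=16000 arch=fsmn extra=${unknown}
```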