From: @wanyiming · tags/v1.2.0-rc1
@@ -76,6 +76,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
     - [AutoDis](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/autodis/README.md)
 - [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
     - [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
+    - [DeepSpeech2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2/README.md)
 - [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
     - [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
     - [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)
| @@ -0,0 +1,262 @@ | |||||
# Contents
- [DeepSpeech2 Description](#deepspeech2-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
- [Training and Eval Process](#training-and-eval-process)
- [Export MindIR](#export-mindir)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
| # [DeepSpeech2 Description](#contents) | |||||
DeepSpeech2 is a speech recognition model trained with CTC loss. It replaces entire pipelines of hand-engineered components with neural networks and can handle a wide variety of speech, including noisy environments, accents and different languages. Training and evaluation are supported on GPU.
[Paper](https://arxiv.org/pdf/1512.02595v1.pdf): Amodei, Dario, et al. Deep Speech 2: End-to-end speech recognition in English and Mandarin.
| # [Model Architecture](#contents) | |||||
| The current reproduced model consists of: | |||||
- two convolutional layers:
    - number of channels is 32, kernel size is [41, 11], stride is [2, 2]
    - number of channels is 32, kernel size is [21, 11], stride is [2, 1]
- five bidirectional LSTM layers (hidden size is 1024)
- one projection layer (output size is the number of characters plus 1 for the CTC blank symbol, 29 in total)
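For reference, the width of the feature vector that the convolutional front end feeds to the LSTM stack follows from the spectrogram settings and the two convolutions above; the small sketch below mirrors the computation in `src/deepspeech2.py` with the default configuration:

```python
import math

# spectrogram settings from the default config
sample_rate, window_size = 16000, 0.02
freq_bins = int(math.floor(sample_rate * window_size / 2) + 1)   # 161 frequency bins per frame

# conv1: kernel (41, 11), stride (2, 2), padding 20 along the frequency axis
after_conv1 = (freq_bins + 2 * 20 - 41) // 2 + 1                 # 81
# conv2: kernel (21, 11), stride (2, 1), padding 10 along the frequency axis
after_conv2 = (after_conv1 + 2 * 10 - 21) // 2 + 1               # 41

rnn_input_size = after_conv2 * 32                                # 32 channels -> 1312 features per time step
print(freq_bins, rnn_input_size)
```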
| # [Dataset](#contents) | |||||
Note that you can run the scripts with the dataset mentioned in the original paper or with any other dataset widely used in this domain. The following sections describe how to run the scripts with the dataset below.
Dataset used: [LibriSpeech](<http://www.openslr.org/12>)
- Train Data:
    - train-clean-100.tar.gz [6.3G] (training set of 100 hours "clean" speech)
    - train-clean-360.tar.gz [23G] (training set of 360 hours "clean" speech)
    - train-other-500.tar.gz [30G] (training set of 500 hours "other" speech)
- Val Data:
    - dev-clean.tar.gz [337M] (development set, "clean" speech)
    - dev-other.tar.gz [314M] (development set, "other", more challenging, speech)
- Test Data:
    - test-clean.tar.gz [346M] (test set, "clean" speech)
    - test-other.tar.gz [328M] (test set, "other" speech)
- Data format: wav and txt files
    - Note: data will be processed in librispeech.py
| # [Environment Requirements](#contents) | |||||
- Hardware (GPU)
    - Prepare hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
| - For more information, please check the resources below: | |||||
| - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) | |||||
| - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) | |||||
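Besides MindSpore itself, the scripts rely on a few third-party Python packages imported in `src/dataset.py` and `src/config.py` (`librosa`, `soundfile`, `easydict`); a minimal way to install them with pip:

```shell
pip install librosa soundfile easydict
```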
| # [Script Description](#contents) | |||||
| ## [Script and Sample Code](#contents) | |||||
```path
.
├── audio
    ├── deepspeech2
        ├── train.py                 // training script
        ├── eval.py                  // evaluation script
        ├── export.py                // convert the MindSpore checkpoint to a MindIR model
        ├── labels.json              // possible characters to map to
        ├── README.md                // descriptions about DeepSpeech2
        ├── deepspeech_pytorch       // third-party decoder package (from SeanNaren's deepspeech.pytorch)
        │   └── decoder.py           // decoder from third-party code (MIT License)
        └── src
            ├── __init__.py
            ├── deepspeech2.py       // DeepSpeech2 network definition
            ├── dataset.py           // dataloader generation and data processing entry
            ├── config.py            // DeepSpeech2 configs
            ├── lr_generator.py      // learning rate generator
            ├── greedydecoder.py     // greedy decoder adapted for MindSpore
            └── callback.py          // callbacks to monitor the training
```
| ## [Script Parameters](#contents) | |||||
| ### Training | |||||
| ```text | |||||
usage: train.py [--pre_trained_model_path PRE_TRAINED_MODEL_PATH]
                [--is_distributed IS_DISTRIBUTED]
                [--bidirectional BIDIRECTIONAL]
| options: | |||||
| --pre_trained_model_path pretrained checkpoint path, default is '' | |||||
| --is_distributed distributed training, default is False | |||||
| --bidirectional whether or not to use bidirectional RNN, default is True. Currently, only bidirectional model is implemented | |||||
| ``` | |||||
| ### Evaluation | |||||
| ```text | |||||
| usage: eval.py [--bidirectional BIDIRECTIONAL] | |||||
| [--pretrain_ckpt PRETRAIN_CKPT] | |||||
| options: | |||||
| --bidirectional whether to use bidirectional RNN, default is True. Currently, only bidirectional model is implemented | |||||
| --pretrain_ckpt saved checkpoint path, default is '' | |||||
| ``` | |||||
| ### Options and Parameters | |||||
Parameters for training and evaluation can be set in `src/config.py`; they are listed below, and a short usage sketch follows the tables.
| ```text | |||||
| config for training. | |||||
epochs                  number of training epochs, default is 70
| ``` | |||||
| ```text | |||||
| config for dataloader. | |||||
| train_manifest train manifest path, default is 'data/libri_train_manifest.csv' | |||||
| val_manifest dev manifest path, default is 'data/libri_val_manifest.csv' | |||||
batch_size              batch size for training, default is 20
| labels_path tokens json path for model output, default is "./labels.json" | |||||
| sample_rate sample rate for the data/model features, default is 16000 | |||||
| window_size window size for spectrogram generation (seconds), default is 0.02 | |||||
| window_stride window stride for spectrogram generation (seconds), default is 0.01 | |||||
| window window type for spectrogram generation, default is 'hamming' | |||||
speed_volume_perturb    use random tempo and gain perturbations, default is False, not used in the current model
spec_augment            use simple spectral augmentation on mel spectrograms, default is False, not used in the current model
noise_dir               directory from which noise is injected into the audio; with the default '' no noise is injected, not used in the current model
noise_prob              probability of noise being added per sample, default is 0.4, not used in the current model
noise_min               minimum noise level to sample from (1.0 means all noise, no original signal), default is 0.0, not used in the current model
noise_max               maximum noise level to sample from (maximum 1.0), default is 0.5, not used in the current model
| ``` | |||||
| ```text | |||||
| config for model. | |||||
| rnn_type type of RNN to use in model, default is 'LSTM'. Currently, only LSTM is supported | |||||
| hidden_size hidden size of RNN Layer, default is 1024 | |||||
| hidden_layers number of RNN layers, default is 5 | |||||
| lookahead_context look ahead context, default is 20, not used in current model | |||||
| ``` | |||||
| ```text | |||||
| config for optimizer. | |||||
| learning_rate initial learning rate, default is 3e-4 | |||||
| learning_anneal annealing applied to learning rate after each epoch, default is 1.1 | |||||
| weight_decay weight decay, default is 1e-5 | |||||
| momentum momentum, default is 0.9 | |||||
| eps Adam eps, default is 1e-8 | |||||
| betas Adam betas, default is (0.9, 0.999) | |||||
| loss_scale loss scale, default is 1024 | |||||
| ``` | |||||
| ```text | |||||
| config for checkpoint. | |||||
ckpt_file_name_prefix   prefix of the checkpoint file name, default is 'DeepSpeech'
ckpt_path               path to save checkpoints, default is './checkpoint'
keep_checkpoint_max     maximum number of checkpoints to keep; older checkpoints are deleted, default is 10
| ``` | |||||
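As a minimal sketch of how these settings are consumed: `src/config.py` stores them in EasyDict objects, so they can be read or overridden with attribute access (in practice the defaults are usually edited directly in that file):

```python
# minimal sketch of reading/overriding the EasyDict-based config from src/config.py
from src.config import train_config, eval_config

print(train_config.TrainingConfig.epochs)        # 70 by default
print(eval_config.DataConfig.test_manifest)      # 'data/libri_test_clean_manifest.csv'

# values can also be overridden programmatically before the model is built
train_config.DataConfig.batch_size = 10          # e.g. shrink the batch for a smaller GPU
train_config.OptimConfig.learning_rate = 1e-4
```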
# [Training and Eval Process](#contents)
Before training, the dataset must be processed. We use the scripts provided by [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) to process the dataset.
The script from [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) automatically downloads the dataset and processes it. After processing, the
dataset directory structure is as follows:
| ```path | |||||
| . | |||||
| ├─ LibriSpeech_dataset | |||||
| │ ├── train | |||||
| │ │ ├─ wav | |||||
| │ │ └─ txt | |||||
| │ ├── val | |||||
| │ │ ├─ wav | |||||
| │ │ └─ txt | |||||
| │ ├── test_clean | |||||
| │ │ ├─ wav | |||||
| │ │ └─ txt | |||||
| │ └── test_other | |||||
| │ ├─ wav | |||||
| │ └─ txt | |||||
| └─ libri_test_clean_manifest.csv, libri_test_other_manifest.csv, libri_train_manifest.csv, libri_val_manifest.csv | |||||
| ``` | |||||
The four *.csv files store the absolute paths of the corresponding data and are used during training and evaluation.
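Each row of a manifest pairs one audio file with its transcript file; `src/dataset.py` simply splits every line on the comma, so a manifest looks roughly like this (the paths below are illustrative):

```text
/path/to/LibriSpeech_dataset/train/wav/1089-134686-0000.wav,/path/to/LibriSpeech_dataset/train/txt/1089-134686-0000.txt
/path/to/LibriSpeech_dataset/train/wav/1089-134686-0001.wav,/path/to/LibriSpeech_dataset/train/txt/1089-134686-0001.txt
```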
| After installing MindSpore via the official website and finishing dataset processing, you can start training as follows: | |||||
| ```shell | |||||
| # standalone training | |||||
| CUDA_VISIBLE_DEVICES='0' python train.py | |||||
| # distributed training | |||||
| CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' mpirun --allow-run-as-root -n 8 python train.py --is_distributed=True > log 2>&1 & | |||||
| ``` | |||||
The following script is used to evaluate the model. Note that only the greedy decoder is supported at the moment. Before running the script,
download the decoder code from [SeanNaren](https://github.com/SeanNaren/deepspeech.pytorch) and place
the `deepspeech_pytorch` directory inside the `deepspeech2` directory. After that, the directory layout matches the one shown in [Script and Sample Code](#script-and-sample-code).
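For example, one way to fetch the decoder is to clone the repository and copy the `deepspeech_pytorch` package into this directory (the commands are illustrative; any equivalent copy works):

```shell
git clone https://github.com/SeanNaren/deepspeech.pytorch.git
cp -r deepspeech.pytorch/deepspeech_pytorch ./deepspeech_pytorch
```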
| ```shell | |||||
| # eval | |||||
| CUDA_VISIBLE_DEVICES='0' python eval.py --pretrain_ckpt='saved_model_path' | |||||
| ``` | |||||
| ## [Export MindIR](#contents) | |||||
| ```bash | |||||
| python export.py --pre_trained_model_path='ckpt_path' | |||||
| ``` | |||||
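`export.py` writes the graph to `deepspeech2.mindir` in the current directory. As a rough sketch, and assuming a MindSpore build that provides `mindspore.load` and `nn.GraphCell` (this API may not be available in every version), the exported graph could be loaded for inference like this:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

# load the exported graph; mindspore.load / nn.GraphCell availability depends on the MindSpore version
graph = ms.load("deepspeech2.mindir")
net = nn.GraphCell(graph)

# dummy input with the fixed shape used in export.py: [batch=1, channel=1, freq=161, time=3500]
inputs = Tensor(np.random.uniform(0.0, 1.0, size=[1, 1, 161, 3500]).astype(np.float32))
length = Tensor(np.array([15], dtype=np.int32))
logits, output_sizes = net(inputs, length)
```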
| # [Model Description](#contents) | |||||
| ## [Performance](#contents) | |||||
| ### Training Performance | |||||
| Parameters                 | DeepSpeech2                                                     |
| -------------------------- | --------------------------------------------------------------- |
| Resource                   | NV SMX2 V100-32G                                                |
| Uploaded Date              | 12/29/2020 (month/day/year)                                     |
| MindSpore Version          | 1.0.0                                                           |
| Dataset                    | LibriSpeech                                                     |
| Training Parameters        | 2p, epoch=70, steps=5144 * epoch, batch_size=20, lr=3e-4        |
| Optimizer                  | Adam                                                            |
| Loss Function              | CTCLoss                                                         |
| outputs                    | probability                                                     |
| Loss                       | 0.2-0.7                                                         |
| Speed                      | 2p: 2.139 s/step                                                |
| Total time                 | 2p: around 1 week                                               |
| Checkpoint                 | 991M (.ckpt file)                                               |
| Scripts                    | [DeepSpeech2 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/deepspeech2) |
| ### Inference Performance | |||||
| Parameters                 | DeepSpeech2                                                      |
| -------------------------- | ---------------------------------------------------------------- |
| Resource                   | NV SMX2 V100-32G                                                 |
| Uploaded Date              | 12/29/2020 (month/day/year)                                      |
| MindSpore Version          | 1.0.0                                                            |
| Dataset                    | LibriSpeech                                                      |
| batch_size                 | 20                                                               |
| outputs                    | probability                                                      |
| Accuracy (test-clean)      | WER: 9.732, CER: 3.270                                           |
| Accuracy (test-other)      | WER: 28.198, CER: 12.253                                         |
| Model for inference        | 330M (.mindir file)                                              |
| # [ModelZoo Homepage](#contents) | |||||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||||
| @@ -0,0 +1,112 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # =========================================================================== | |||||
| """ | |||||
| Eval DeepSpeech2 | |||||
| """ | |||||
| import argparse | |||||
| import json | |||||
| import pickle | |||||
| import numpy as np | |||||
| from src.config import eval_config | |||||
| from src.deepspeech2 import DeepSpeechModel, PredictWithSoftmax | |||||
| from src.dataset import create_dataset | |||||
| from src.greedydecoder import MSGreedyDecoder | |||||
| from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False) | |||||
| parser = argparse.ArgumentParser(description='DeepSpeech evaluation') | |||||
| parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN') | |||||
| parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path') | |||||
| args = parser.parse_args() | |||||
| if __name__ == '__main__': | |||||
| config = eval_config | |||||
| with open(config.DataConfig.labels_path) as label_file: | |||||
| labels = json.load(label_file) | |||||
| model = PredictWithSoftmax(DeepSpeechModel(batch_size=config.DataConfig.batch_size, | |||||
| rnn_hidden_size=config.ModelConfig.hidden_size, | |||||
| nb_layers=config.ModelConfig.hidden_layers, | |||||
| labels=labels, | |||||
| rnn_type=config.ModelConfig.rnn_type, | |||||
| audio_conf=config.DataConfig.SpectConfig, | |||||
| bidirectional=args.bidirectional)) | |||||
| ds_eval = create_dataset(audio_conf=config.DataConfig.SpectConfig, | |||||
| manifest_filepath=config.DataConfig.test_manifest, | |||||
| labels=labels, normalize=True, train_mode=False, | |||||
| batch_size=config.DataConfig.batch_size, rank=0, group_size=1) | |||||
| param_dict = load_checkpoint(args.pretrain_ckpt) | |||||
| load_param_into_net(model, param_dict) | |||||
print('Successfully loaded the pre-trained model')
| if config.LMConfig.decoder_type == 'greedy': | |||||
| decoder = MSGreedyDecoder(labels=labels, blank_index=labels.index('_')) | |||||
| else: | |||||
| raise NotImplementedError("Only greedy decoder is supported now") | |||||
| target_decoder = MSGreedyDecoder(labels, blank_index=labels.index('_')) | |||||
| model.set_train(False) | |||||
| total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0 | |||||
| output_data = [] | |||||
| for data in ds_eval.create_dict_iterator(): | |||||
| inputs, input_length, target_indices, targets = data['inputs'], data['input_length'], data['target_indices'], \ | |||||
| data['label_values'] | |||||
| split_targets = [] | |||||
| start, count, last_id = 0, 0, 0 | |||||
| target_indices, targets = target_indices.asnumpy(), targets.asnumpy() | |||||
# target_indices holds (sample index, position) pairs and targets holds the flattened label values
# produced by src/dataset.py; regroup them into one label list per utterance
for i in range(np.shape(targets)[0]):
if target_indices[i, 0] == last_id:
count += 1
else:
split_targets.append(list(targets[start:count]))
last_id += 1
start = count
count += 1
# after the loop finishes, append the final group so the last utterance of the batch is not dropped
split_targets.append(list(targets[start:]))
| out, output_sizes = model(inputs, input_length) | |||||
| decoded_output, _ = decoder.decode(out, output_sizes) | |||||
| target_strings = target_decoder.convert_to_strings(split_targets) | |||||
| if config.save_output is not None: | |||||
| output_data.append((out.asnumpy(), output_sizes.asnumpy(), target_strings)) | |||||
| for doutput, toutput in zip(decoded_output, target_strings): | |||||
| transcript, reference = doutput[0], toutput[0] | |||||
| wer_inst = decoder.wer(transcript, reference) | |||||
| cer_inst = decoder.cer(transcript, reference) | |||||
| total_wer += wer_inst | |||||
| total_cer += cer_inst | |||||
| num_tokens += len(reference.split()) | |||||
| num_chars += len(reference.replace(' ', '')) | |||||
| if config.verbose: | |||||
| print("Ref:", reference.lower()) | |||||
| print("Hyp:", transcript.lower()) | |||||
| print("WER:", float(wer_inst) / len(reference.split()), | |||||
| "CER:", float(cer_inst) / len(reference.replace(' ', '')), "\n") | |||||
| wer = float(total_wer) / num_tokens | |||||
| cer = float(total_cer) / num_chars | |||||
| print('Test Summary \t' | |||||
| 'Average WER {wer:.3f}\t' | |||||
| 'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100)) | |||||
| if config.save_output is not None: | |||||
| with open(config.save_output + '.bin', 'wb') as output: | |||||
| pickle.dump(output_data, output) | |||||
| @@ -0,0 +1,51 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| export checkpoint file to mindir model | |||||
| """ | |||||
| import json | |||||
| import argparse | |||||
| import numpy as np | |||||
| from mindspore import context, Tensor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net, export | |||||
| from src.deepspeech2 import DeepSpeechModel | |||||
| from src.config import train_config | |||||
| parser = argparse.ArgumentParser(description='Export DeepSpeech model to Mindir') | |||||
parser.add_argument('--pre_trained_model_path', type=str, default='', help='existing checkpoint path')
| args = parser.parse_args() | |||||
| if __name__ == '__main__': | |||||
| config = train_config | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False) | |||||
| with open(config.DataConfig.labels_path) as label_file: | |||||
| labels = json.load(label_file) | |||||
| deepspeech_net = DeepSpeechModel(batch_size=1, | |||||
| rnn_hidden_size=config.ModelConfig.hidden_size, | |||||
| nb_layers=config.ModelConfig.hidden_layers, | |||||
| labels=labels, | |||||
| rnn_type=config.ModelConfig.rnn_type, | |||||
| audio_conf=config.DataConfig.SpectConfig, | |||||
| bidirectional=True) | |||||
| param_dict = load_checkpoint(args.pre_trained_model_path) | |||||
| load_param_into_net(deepspeech_net, param_dict) | |||||
print('Successfully loaded the pre-trained model')
| # 3500 is the max length in evaluation dataset(LibriSpeech). This is consistent with that in dataset.py | |||||
| # The length is fixed to this value because Mindspore does not support dynamic shape currently | |||||
| input_np = np.random.uniform(0.0, 1.0, size=[1, 1, 161, 3500]).astype(np.float32) | |||||
| length = np.array([15], dtype=np.int32) | |||||
| export(deepspeech_net, Tensor(input_np), Tensor(length), file_name="deepspeech2.mindir", file_format='MINDIR') | |||||
| @@ -0,0 +1,31 @@ | |||||
| [ | |||||
| "'", | |||||
| "A", | |||||
| "B", | |||||
| "C", | |||||
| "D", | |||||
| "E", | |||||
| "F", | |||||
| "G", | |||||
| "H", | |||||
| "I", | |||||
| "J", | |||||
| "K", | |||||
| "L", | |||||
| "M", | |||||
| "N", | |||||
| "O", | |||||
| "P", | |||||
| "Q", | |||||
| "R", | |||||
| "S", | |||||
| "T", | |||||
| "U", | |||||
| "V", | |||||
| "W", | |||||
| "X", | |||||
| "Y", | |||||
| "Z", | |||||
| " ", | |||||
| "_" | |||||
| ] | |||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,108 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Defined callback for DeepSpeech. | |||||
| """ | |||||
| import time | |||||
| from mindspore.train.callback import Callback | |||||
| from mindspore import Tensor | |||||
| import numpy as np | |||||
| class TimeMonitor(Callback): | |||||
| """ | |||||
| Time monitor for calculating cost of each epoch. | |||||
Args:
    data_size (int): step size of an epoch.
| """ | |||||
| def __init__(self, data_size): | |||||
| super(TimeMonitor, self).__init__() | |||||
| self.data_size = data_size | |||||
| def epoch_begin(self, run_context): | |||||
| self.epoch_time = time.time() | |||||
| def epoch_end(self, run_context): | |||||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||||
| per_step_mseconds = epoch_mseconds / self.data_size | |||||
| print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) | |||||
| def step_begin(self, run_context): | |||||
| self.step_time = time.time() | |||||
| def step_end(self, run_context): | |||||
| step_mseconds = (time.time() - self.step_time) * 1000 | |||||
| print(f"step time {step_mseconds}", flush=True) | |||||
| class Monitor(Callback): | |||||
| """ | |||||
| Monitor loss and time. | |||||
| Args: | |||||
| lr_init (numpy array): train lr | |||||
| Returns: | |||||
| None | |||||
| """ | |||||
| def __init__(self, lr_init=None): | |||||
| super(Monitor, self).__init__() | |||||
| self.lr_init = lr_init | |||||
| self.lr_init_len = len(lr_init) | |||||
| def epoch_begin(self, run_context): | |||||
| self.losses = [] | |||||
| self.epoch_time = time.time() | |||||
| def epoch_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| epoch_mseconds = (time.time() - self.epoch_time) | |||||
| per_step_mseconds = epoch_mseconds / cb_params.batch_num | |||||
| print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, | |||||
| per_step_mseconds, | |||||
| np.mean(self.losses))) | |||||
| def step_begin(self, run_context): | |||||
| self.step_time = time.time() | |||||
| def step_end(self, run_context): | |||||
| """ | |||||
| Args: | |||||
| run_context: | |||||
| Returns: | |||||
| """ | |||||
| cb_params = run_context.original_args() | |||||
| step_mseconds = (time.time() - self.step_time) | |||||
| step_loss = cb_params.net_outputs | |||||
| if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): | |||||
| step_loss = step_loss[0] | |||||
| if isinstance(step_loss, Tensor): | |||||
| step_loss = np.mean(step_loss.asnumpy()) | |||||
| self.losses.append(step_loss) | |||||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num | |||||
| print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:.9f}]".format( | |||||
| cb_params.cur_epoch_num - | |||||
| 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, | |||||
| np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1].asnumpy())) | |||||
| @@ -0,0 +1,113 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # =========================================================================== | |||||
| """ | |||||
| network config setting, will be used in train.py and eval.py | |||||
| """ | |||||
| from easydict import EasyDict as ed | |||||
| train_config = ed({ | |||||
| "TrainingConfig": { | |||||
| "epochs": 70, | |||||
| }, | |||||
| "DataConfig": { | |||||
| "train_manifest": 'data/libri_train_manifest.csv', | |||||
| # "val_manifest": 'data/libri_val_manifest.csv', | |||||
| "batch_size": 20, | |||||
| "labels_path": "labels.json", | |||||
| "SpectConfig": { | |||||
| "sample_rate": 16000, | |||||
| "window_size": 0.02, | |||||
| "window_stride": 0.01, | |||||
| "window": "hamming" | |||||
| }, | |||||
| "AugmentationConfig": { | |||||
| "speed_volume_perturb": False, | |||||
| "spec_augment": False, | |||||
| "noise_dir": '', | |||||
| "noise_prob": 0.4, | |||||
| "noise_min": 0.0, | |||||
| "noise_max": 0.5, | |||||
| } | |||||
| }, | |||||
| "ModelConfig": { | |||||
| "rnn_type": "LSTM", | |||||
| "hidden_size": 1024, | |||||
| "hidden_layers": 5, | |||||
| "lookahead_context": 20, | |||||
| }, | |||||
| "OptimConfig": { | |||||
| "learning_rate": 3e-4, | |||||
| "learning_anneal": 1.1, | |||||
| "weight_decay": 1e-5, | |||||
| "momentum": 0.9, | |||||
| "eps": 1e-8, | |||||
| "betas": (0.9, 0.999), | |||||
| "loss_scale": 1024, | |||||
| "epsilon": 0.00001 | |||||
| }, | |||||
| "CheckpointConfig": { | |||||
| "ckpt_file_name_prefix": 'DeepSpeech', | |||||
| "ckpt_path": './checkpoint', | |||||
| "keep_checkpoint_max": 10 | |||||
| } | |||||
| }) | |||||
| eval_config = ed({ | |||||
| "save_output": 'librispeech_val_output', | |||||
| "verbose": True, | |||||
| "DataConfig": { | |||||
| "test_manifest": 'data/libri_test_clean_manifest.csv', | |||||
| # "test_manifest": 'data/libri_test_other_manifest.csv', | |||||
| # "test_manifest": 'data/libri_val_manifest.csv', | |||||
| "batch_size": 20, | |||||
| "labels_path": "labels.json", | |||||
| "SpectConfig": { | |||||
| "sample_rate": 16000, | |||||
| "window_size": 0.02, | |||||
| "window_stride": 0.01, | |||||
| "window": "hanning" | |||||
| }, | |||||
| }, | |||||
| "ModelConfig": { | |||||
| "rnn_type": "LSTM", | |||||
| "hidden_size": 1024, | |||||
| "hidden_layers": 5, | |||||
| "lookahead_context": 20, | |||||
| }, | |||||
| "LMConfig": { | |||||
| "decoder_type": "greedy", | |||||
| "lm_path": './3-gram.pruned.3e-7.arpa', | |||||
| "top_paths": 1, | |||||
| "alpha": 1.818182, | |||||
| "beta": 0, | |||||
| "cutoff_top_n": 40, | |||||
| "cutoff_prob": 1.0, | |||||
| "beam_width": 1024, | |||||
| "lm_workers": 4 | |||||
| }, | |||||
| }) | |||||
| @@ -0,0 +1,215 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Create train or eval dataset. | |||||
| """ | |||||
| import math | |||||
| import numpy as np | |||||
| import mindspore.dataset.engine as de | |||||
| import librosa | |||||
| import soundfile as sf | |||||
| TRAIN_INPUT_PAD_LENGTH = 1501 | |||||
| TRAIN_LABEL_PAD_LENGTH = 350 | |||||
| TEST_INPUT_PAD_LENGTH = 3500 | |||||
| class LoadAudioAndTranscript(): | |||||
| """ | |||||
| parse audio and transcript | |||||
| """ | |||||
| def __init__(self, | |||||
| audio_conf=None, | |||||
| normalize=False, | |||||
| labels=None): | |||||
| super(LoadAudioAndTranscript, self).__init__() | |||||
| self.window_stride = audio_conf.window_stride | |||||
| self.window_size = audio_conf.window_size | |||||
| self.sample_rate = audio_conf.sample_rate | |||||
| self.window = audio_conf.window | |||||
| self.is_normalization = normalize | |||||
| self.labels = labels | |||||
| def load_audio(self, path): | |||||
| """ | |||||
| load audio | |||||
| """ | |||||
| sound, _ = sf.read(path, dtype='int16') | |||||
| sound = sound.astype('float32') / 32767 | |||||
| if len(sound.shape) > 1: | |||||
| if sound.shape[1] == 1: | |||||
| sound = sound.squeeze() | |||||
| else: | |||||
| sound = sound.mean(axis=1) | |||||
| return sound | |||||
| def parse_audio(self, audio_path): | |||||
| """ | |||||
| parse audio | |||||
| """ | |||||
| audio = self.load_audio(audio_path) | |||||
| n_fft = int(self.sample_rate * self.window_size) | |||||
| win_length = n_fft | |||||
| hop_length = int(self.sample_rate * self.window_stride) | |||||
| D = librosa.stft(y=audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=self.window) | |||||
| mag, _ = librosa.magphase(D) | |||||
| mag = np.log1p(mag) | |||||
| if self.is_normalization: | |||||
| mean = mag.mean() | |||||
| std = mag.std() | |||||
| mag = (mag - mean) / std | |||||
| return mag | |||||
| def parse_transcript(self, transcript_path): | |||||
| with open(transcript_path, 'r', encoding='utf8') as transcript_file: | |||||
| transcript = transcript_file.read().replace('\n', '') | |||||
| transcript = list(filter(None, [self.labels.get(x) for x in list(transcript)])) | |||||
| return transcript | |||||
| class ASRDataset(LoadAudioAndTranscript): | |||||
| """ | |||||
| create ASRDataset | |||||
| Args: | |||||
| audio_conf: Config containing the sample rate, window and the window length/stride in seconds | |||||
| manifest_filepath (str): manifest_file path. | |||||
| labels (list): List containing all the possible characters to map to | |||||
| normalize: Apply standard mean and deviation normalization to audio tensor | |||||
| batch_size (int): Dataset batch size (default=32) | |||||
| """ | |||||
| def __init__(self, audio_conf=None, | |||||
| manifest_filepath='', | |||||
| labels=None, | |||||
| normalize=False, | |||||
| batch_size=32, | |||||
| is_training=True): | |||||
| with open(manifest_filepath) as f: | |||||
| ids = f.readlines() | |||||
| ids = [x.strip().split(',') for x in ids] | |||||
| self.is_training = is_training | |||||
| self.ids = ids | |||||
| self.blank_id = int(labels.index('_')) | |||||
| self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)] | |||||
| if len(self.ids) % batch_size != 0: | |||||
| self.bins = self.bins[:-1] | |||||
| self.bins.append(ids[-batch_size:]) | |||||
| self.size = len(self.bins) | |||||
| self.batch_size = batch_size | |||||
| self.labels_map = {labels[i]: i for i in range(len(labels))} | |||||
| super(ASRDataset, self).__init__(audio_conf, normalize, self.labels_map) | |||||
| def __getitem__(self, index): | |||||
| batch_idx = self.bins[index] | |||||
| batch_size = len(batch_idx) | |||||
| batch_spect, batch_script, target_indices = [], [], [] | |||||
| input_length = np.zeros(batch_size, np.int32) | |||||
| for data in batch_idx: | |||||
| audio_path, transcript_path = data[0], data[1] | |||||
| spect = self.parse_audio(audio_path) | |||||
| transcript = self.parse_transcript(transcript_path) | |||||
| batch_spect.append(spect) | |||||
| batch_script.append(transcript) | |||||
| freq_size = np.shape(batch_spect[-1])[0] | |||||
| if self.is_training: | |||||
| # 1501 is the max length in train dataset(LibriSpeech). | |||||
| # The length is fixed to this value because Mindspore does not support dynamic shape currently | |||||
| inputs = np.zeros((batch_size, 1, freq_size, TRAIN_INPUT_PAD_LENGTH), dtype=np.float32) | |||||
| # The target length is fixed to this value because Mindspore does not support dynamic shape currently | |||||
| # 350 may be greater than the max length of labels in train dataset(LibriSpeech). | |||||
| targets = np.ones((self.batch_size, TRAIN_LABEL_PAD_LENGTH), dtype=np.int32) * self.blank_id | |||||
| for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script): | |||||
| seq_length = np.shape(spect_)[1] | |||||
| input_length[k] = seq_length | |||||
| script_length = len(scripts_) | |||||
| targets[k, :script_length] = scripts_ | |||||
for m in range(TRAIN_LABEL_PAD_LENGTH):
| target_indices.append([k, m]) | |||||
| inputs[k, 0, :, 0:seq_length] = spect_ | |||||
| targets = np.reshape(targets, (-1,)) | |||||
| else: | |||||
| inputs = np.zeros((batch_size, 1, freq_size, TEST_INPUT_PAD_LENGTH), dtype=np.float32) | |||||
| targets = [] | |||||
| for k, spect_, scripts_ in zip(range(batch_size), batch_spect, batch_script): | |||||
| seq_length = np.shape(spect_)[1] | |||||
| input_length[k] = seq_length | |||||
| targets.extend(scripts_) | |||||
| for m in range(len(scripts_)): | |||||
| target_indices.append([k, m]) | |||||
| inputs[k, 0, :, 0:seq_length] = spect_ | |||||
| return inputs, input_length, np.array(target_indices, dtype=np.int64), np.array(targets, dtype=np.int32) | |||||
| def __len__(self): | |||||
| return self.size | |||||
| class DistributedSampler(): | |||||
| """ | |||||
| function to distribute and shuffle sample | |||||
| """ | |||||
| def __init__(self, dataset, rank, group_size, shuffle=True, seed=0): | |||||
| self.dataset = dataset | |||||
| self.rank = rank | |||||
| self.group_size = group_size | |||||
| self.dataset_len = len(self.dataset) | |||||
| self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size)) | |||||
| self.total_size = self.num_samplers * self.group_size | |||||
| self.shuffle = shuffle | |||||
| self.seed = seed | |||||
| def __iter__(self): | |||||
| if self.shuffle: | |||||
| self.seed = (self.seed + 1) & 0xffffffff | |||||
| np.random.seed(self.seed) | |||||
| indices = np.random.permutation(self.dataset_len).tolist() | |||||
| else: | |||||
| indices = list(range(self.dataset_len)) | |||||
| indices += indices[:(self.total_size - len(indices))] | |||||
| indices = indices[self.rank::self.group_size] | |||||
| return iter(indices) | |||||
| def __len__(self): | |||||
| return self.num_samplers | |||||
| def create_dataset(audio_conf, manifest_filepath, labels, normalize, batch_size, train_mode=True, | |||||
| rank=None, group_size=None): | |||||
| """ | |||||
| create train dataset | |||||
| Args: | |||||
| audio_conf: Config containing the sample rate, window and the window length/stride in seconds | |||||
| manifest_filepath (str): manifest_file path. | |||||
| labels (list): list containing all the possible characters to map to | |||||
| normalize: Apply standard mean and deviation normalization to audio tensor | |||||
| train_mode (bool): Whether dataset is use for train or eval (default=True). | |||||
| batch_size (int): Dataset batch size | |||||
| rank (int): The shard ID within num_shards (default=None). | |||||
| group_size (int): Number of shards that the dataset should be divided into (default=None). | |||||
| Returns: | |||||
| Dataset. | |||||
| """ | |||||
| dataset = ASRDataset(audio_conf=audio_conf, manifest_filepath=manifest_filepath, labels=labels, normalize=normalize, | |||||
| batch_size=batch_size, is_training=train_mode) | |||||
| sampler = DistributedSampler(dataset, rank, group_size, shuffle=True) | |||||
| ds = de.GeneratorDataset(dataset, ["inputs", "input_length", "target_indices", "label_values"], sampler=sampler) | |||||
| ds = ds.repeat(1) | |||||
| return ds | |||||
| @@ -0,0 +1,300 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| DeepSpeech2 model | |||||
| """ | |||||
| import math | |||||
| import numpy as np | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import nn, Tensor, ParameterTuple, Parameter | |||||
| from mindspore.common.initializer import initializer | |||||
| class SequenceWise(nn.Cell): | |||||
| """ | |||||
| SequenceWise FC Layers. | |||||
| """ | |||||
| def __init__(self, module): | |||||
| super(SequenceWise, self).__init__() | |||||
| self.module = module | |||||
| self.reshape_op = P.Reshape() | |||||
| self.shape_op = P.Shape() | |||||
| self._initialize_weights() | |||||
| def construct(self, x): | |||||
| sizes = self.shape_op(x) | |||||
| t, n = sizes[0], sizes[1] | |||||
| x = self.reshape_op(x, (t * n, -1)) | |||||
| x = self.module(x) | |||||
| x = self.reshape_op(x, (t, n, -1)) | |||||
| return x | |||||
| def _initialize_weights(self): | |||||
| self.init_parameters_data() | |||||
| for _, m in self.cells_and_names(): | |||||
| if isinstance(m, nn.Dense): | |||||
| m.weight.set_data(Tensor( | |||||
| np.random.uniform(-1. / m.in_channels, 1. / m.in_channels, m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | |||||
| m.bias.set_data(Tensor( | |||||
| np.random.uniform(-1. / m.in_channels, 1. / m.in_channels, m.bias.data.shape).astype( | |||||
| "float32"))) | |||||
| class MaskConv(nn.Cell): | |||||
| """ | |||||
MaskConv architecture. MaskConv is not actually implemented in this part because some operations are not yet
supported in MindSpore. `lengths` is kept for future use.
| """ | |||||
| def __init__(self): | |||||
| super(MaskConv, self).__init__() | |||||
| self.zeros = P.ZerosLike() | |||||
| self.conv1 = nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), pad_mode='pad', padding=(20, 20, 5, 5)) | |||||
| self.bn1 = nn.BatchNorm2d(num_features=32) | |||||
| self.conv2 = nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), pad_mode='pad', padding=(10, 10, 5, 5)) | |||||
| self.bn2 = nn.BatchNorm2d(num_features=32) | |||||
| self.tanh = nn.Tanh() | |||||
| self._initialize_weights() | |||||
| self.module_list = nn.CellList([self.conv1, self.bn1, self.tanh, self.conv2, self.bn2, self.tanh]) | |||||
| def construct(self, x, lengths): | |||||
| for module in self.module_list: | |||||
| x = module(x) | |||||
| return x | |||||
| def _initialize_weights(self): | |||||
| """ | |||||
| parameter initialization | |||||
| """ | |||||
| self.init_parameters_data() | |||||
| for _, m in self.cells_and_names(): | |||||
| if isinstance(m, nn.Conv2d): | |||||
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||||
| m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), | |||||
| m.weight.data.shape).astype("float32"))) | |||||
| if m.bias is not None: | |||||
| m.bias.set_data( | |||||
| Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) | |||||
| elif isinstance(m, nn.BatchNorm2d): | |||||
| m.gamma.set_data( | |||||
| Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) | |||||
| m.beta.set_data( | |||||
| Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) | |||||
| class BatchRNN(nn.Cell): | |||||
| """ | |||||
| BatchRNN architecture. | |||||
| Args: | |||||
batch_size(int): number of samples per step in training
| input_size (int): dimension of input tensor | |||||
| hidden_size(int): rnn hidden size | |||||
| num_layers(int): rnn layers | |||||
| bidirectional(bool): use bidirectional rnn (default=True). Currently, only bidirectional rnn is implemented. | |||||
| batch_norm(bool): whether to use batchnorm in RNN. Currently, GPU does not support batch_norm1D (default=False). | |||||
| rnn_type (str): rnn type to use (default='LSTM'). Currently, only LSTM is supported. | |||||
| """ | |||||
| def __init__(self, batch_size, input_size, hidden_size, num_layers, bidirectional=False, batch_norm=False, | |||||
| rnn_type='LSTM'): | |||||
| super(BatchRNN, self).__init__() | |||||
| self.batch_size = batch_size | |||||
| self.input_size = input_size | |||||
| self.hidden_size = hidden_size | |||||
| self.num_layers = num_layers | |||||
| self.rnn_type = rnn_type | |||||
| self.bidirectional = bidirectional | |||||
| self.has_bias = True | |||||
| self.is_batch_norm = batch_norm | |||||
| self.num_directions = 2 if bidirectional else 1 | |||||
| self.reshape_op = P.Reshape() | |||||
| self.shape_op = P.Shape() | |||||
| self.sum_op = P.ReduceSum() | |||||
| input_size_list = [input_size] | |||||
| for i in range(num_layers - 1): | |||||
| input_size_list.append(hidden_size) | |||||
| layers = [] | |||||
| for i in range(num_layers): | |||||
| layers.append( | |||||
| nn.LSTMCell(input_size=input_size_list[i], hidden_size=hidden_size, bidirectional=bidirectional, | |||||
| has_bias=self.has_bias)) | |||||
| weights = [] | |||||
| for i in range(num_layers): | |||||
| weight_size = (input_size_list[i] + hidden_size) * hidden_size * self.num_directions * 4 | |||||
| if self.has_bias: | |||||
| bias_size = self.num_directions * hidden_size * 4 * 2 | |||||
| weight_size = weight_size + bias_size | |||||
| stdv = 1 / math.sqrt(hidden_size) | |||||
| w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32) | |||||
| weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i))) | |||||
| self.h, self.c = self.stack_lstm_default_state(batch_size, hidden_size, num_layers=num_layers, | |||||
| bidirectional=bidirectional) | |||||
| self.lstms = layers | |||||
| self.weight = ParameterTuple(tuple(weights)) | |||||
| if batch_norm: | |||||
| batch_norm_layer = [] | |||||
| for i in range(num_layers - 1): | |||||
| batch_norm_layer.append(nn.BatchNorm1d(hidden_size)) | |||||
| self.batch_norm_list = batch_norm_layer | |||||
| def stack_lstm_default_state(self, batch_size, hidden_size, num_layers, bidirectional): | |||||
| """init default input.""" | |||||
| num_directions = 2 if bidirectional else 1 | |||||
h_list, c_list = [], []  # use separate lists so h and c each hold exactly one zero state per layer
| for _ in range(num_layers): | |||||
| h_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))) | |||||
| c_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))) | |||||
| h, c = tuple(h_list), tuple(c_list) | |||||
| return h, c | |||||
| def construct(self, x): | |||||
| for i in range(self.num_layers): | |||||
| if self.is_batch_norm and i > 0: | |||||
| x = self.batch_norm_list[i - 1](x) | |||||
| x, _, _, _, _ = self.lstms[i](x, self.h[i], self.c[i], self.weight[i]) | |||||
| if self.bidirectional: | |||||
| size = self.shape_op(x) | |||||
| x = self.reshape_op(x, (size[0], size[1], 2, -1)) | |||||
| x = self.sum_op(x, 2) | |||||
| return x | |||||
| class DeepSpeechModel(nn.Cell): | |||||
| """ | |||||
DeepSpeech2 architecture.
Args:
batch_size(int): number of samples per step in training (default=128)
| rnn_type (str): rnn type to use (default="LSTM") | |||||
| labels (list): list containing all the possible characters to map to | |||||
| rnn_hidden_size(int): rnn hidden size | |||||
| nb_layers(int): number of rnn layers | |||||
| audio_conf: Config containing the sample rate, window and the window length/stride in seconds | |||||
| bidirectional(bool): use bidirectional rnn (default=True) | |||||
| """ | |||||
| def __init__(self, batch_size, labels, rnn_hidden_size, nb_layers, audio_conf, rnn_type='LSTM', bidirectional=True): | |||||
| super(DeepSpeechModel, self).__init__() | |||||
| self.batch_size = batch_size | |||||
| self.hidden_size = rnn_hidden_size | |||||
| self.hidden_layers = nb_layers | |||||
| self.rnn_type = rnn_type | |||||
| self.audio_conf = audio_conf | |||||
| self.labels = labels | |||||
| self.bidirectional = bidirectional | |||||
| self.reshape_op = P.Reshape() | |||||
| self.shape_op = P.Shape() | |||||
| self.transpose_op = P.Transpose() | |||||
| self.add = P.TensorAdd() | |||||
| self.div = P.Div() | |||||
| sample_rate = self.audio_conf.sample_rate | |||||
| window_size = self.audio_conf.window_size | |||||
| num_classes = len(self.labels) | |||||
| self.conv = MaskConv() | |||||
# This is used to compute the output sequence lengths after the convolutions (see get_seq_lens)
self.pre, self.stride = self.get_conv_num()
# Based on the above convolutions and spectrogram size, using the conv output formula (W - F + 2P) / S + 1
| rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1) | |||||
| rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1) | |||||
| rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1) | |||||
| rnn_input_size *= 32 | |||||
| self.RNN = BatchRNN(batch_size=self.batch_size, input_size=rnn_input_size, num_layers=nb_layers, | |||||
| hidden_size=rnn_hidden_size, bidirectional=bidirectional, batch_norm=False, | |||||
| rnn_type=self.rnn_type) | |||||
| fully_connected = nn.Dense(rnn_hidden_size, num_classes, has_bias=False) | |||||
| self.fc = SequenceWise(fully_connected) | |||||
| def construct(self, x, lengths): | |||||
| """ | |||||
| lengths is actually not used in this part since Mindspore does not support dynamic shape. | |||||
| """ | |||||
| output_lengths = self.get_seq_lens(lengths) | |||||
| x = self.conv(x, lengths) | |||||
| sizes = self.shape_op(x) | |||||
| x = self.reshape_op(x, (sizes[0], sizes[1] * sizes[2], sizes[3])) | |||||
| x = self.transpose_op(x, (2, 0, 1)) | |||||
| x = self.RNN(x) | |||||
| x = self.fc(x) | |||||
| return x, output_lengths | |||||
| def get_seq_lens(self, seq_len): | |||||
| """ | |||||
| Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable | |||||
| containing the size sequences that will be output by the network. | |||||
| """ | |||||
| for i in range(len(self.stride)): | |||||
| seq_len = self.add(self.div(self.add(seq_len, self.pre[i]), self.stride[i]), 1) | |||||
| return seq_len | |||||
| def get_conv_num(self): | |||||
| p, s = [], [] | |||||
| for _, cell in self.conv.cells_and_names(): | |||||
| if isinstance(cell, nn.Conv2d): | |||||
| kernel_size = cell.kernel_size | |||||
| padding_1 = int((kernel_size[1] - 1) / 2) | |||||
| temp = 2 * padding_1 - cell.dilation[1] * (cell.kernel_size[1] - 1) - 1 | |||||
| p.append(temp) | |||||
| s.append(cell.stride[1]) | |||||
| return p, s | |||||
| class NetWithLossClass(nn.Cell): | |||||
| """ | |||||
| NetWithLossClass definition | |||||
| """ | |||||
| def __init__(self, network): | |||||
| super(NetWithLossClass, self).__init__(auto_prefix=False) | |||||
| self.loss = P.CTCLoss(ctc_merge_repeated=True) | |||||
| self.network = network | |||||
| self.ReduceMean_false = P.ReduceMean(keep_dims=False) | |||||
| self.squeeze_op = P.Squeeze(0) | |||||
| def construct(self, inputs, input_length, target_indices, label_values): | |||||
| predict, output_length = self.network(inputs, input_length) | |||||
| loss = self.loss(predict, target_indices, label_values, output_length) | |||||
| return self.ReduceMean_false(loss[0]) | |||||
| class PredictWithSoftmax(nn.Cell): | |||||
| """ | |||||
| PredictWithSoftmax | |||||
| """ | |||||
| def __init__(self, network): | |||||
| super(PredictWithSoftmax, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.inference_softmax = P.Softmax(axis=-1) | |||||
| self.transpose_op = P.Transpose() | |||||
| def construct(self, inputs, input_length): | |||||
| x, output_sizes = self.network(inputs, input_length) | |||||
| x = self.inference_softmax(x) | |||||
| x = self.transpose_op(x, (1, 0, 2)) | |||||
| return x, output_sizes | |||||
| @@ -0,0 +1,52 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| modify GreedyDecoder to adapt to MindSpore | |||||
| """ | |||||
| import numpy as np | |||||
| from deepspeech_pytorch.decoder import GreedyDecoder | |||||
| class MSGreedyDecoder(GreedyDecoder): | |||||
| """ | |||||
| GreedyDecoder used for MindSpore | |||||
| """ | |||||
| def process_string(self, sequence, size, remove_repetitions=False): | |||||
| """ | |||||
| process string | |||||
| """ | |||||
| string = '' | |||||
| offsets = [] | |||||
| for i in range(size): | |||||
| char = self.int_to_char[sequence[i].item()] | |||||
| if char != self.int_to_char[self.blank_index]: | |||||
| if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]: | |||||
| pass | |||||
| elif char == self.labels[self.space_index]: | |||||
| string += ' ' | |||||
| offsets.append(i) | |||||
| else: | |||||
| string = string + char | |||||
| offsets.append(i) | |||||
| return string, offsets | |||||
| def decode(self, probs, sizes=None): | |||||
| probs = probs.asnumpy() | |||||
| sizes = sizes.asnumpy() | |||||
| max_probs = np.argmax(probs, axis=-1) | |||||
| strings, offsets = self.convert_to_strings(max_probs, sizes, remove_repetitions=True, return_offsets=True) | |||||
| return strings, offsets | |||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """learning rate generator""" | |||||
| import numpy as np | |||||
| def get_lr(lr_init, total_epochs, steps_per_epoch): | |||||
| """ | |||||
| generate learning rate array | |||||
| Args: | |||||
| lr_init(float): init learning rate | |||||
| total_epochs(int): total epoch of training | |||||
| steps_per_epoch(int): steps of one epoch | |||||
| Returns: | |||||
| np.array, learning rate array | |||||
| """ | |||||
| lr_each_step = [] | |||||
| half_epoch = total_epochs // 2 | |||||
for i in range(total_epochs * steps_per_epoch):
epoch = i // steps_per_epoch
if epoch < half_epoch:
lr_each_step.append(lr_init)
else:
# anneal by the learning_anneal factor (1.1) once per epoch after the first half of training
lr_each_step.append(lr_init / (1.1 ** (epoch - half_epoch)))
| learning_rate = np.array(lr_each_step).astype(np.float32) | |||||
| return learning_rate | |||||
| @@ -0,0 +1,103 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train_criteo.""" | |||||
| import os | |||||
| import json | |||||
| import argparse | |||||
| from mindspore import context, Tensor, ParameterTuple | |||||
| from mindspore.context import ParallelMode | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.nn.optim import Adam | |||||
| from mindspore.nn import TrainOneStepCell | |||||
| from mindspore.train import Model | |||||
| from src.deepspeech2 import DeepSpeechModel, NetWithLossClass | |||||
| from src.lr_generator import get_lr | |||||
| from src.callback import Monitor | |||||
| from src.config import train_config | |||||
| from src.dataset import create_dataset | |||||
| parser = argparse.ArgumentParser(description='DeepSpeech2 training') | |||||
| parser.add_argument('--pre_trained_model_path', type=str, default='', help='Pretrained checkpoint path') | |||||
| parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training') | |||||
| parser.add_argument('--bidirectional', action="store_false", default=True, help='Use bidirectional RNN') | |||||
| args = parser.parse_args() | |||||
| if __name__ == '__main__': | |||||
| rank_id = 0 | |||||
| group_size = 1 | |||||
| config = train_config | |||||
| if args.is_distributed: | |||||
| init('nccl') | |||||
| rank_id = get_rank() | |||||
| group_size = get_group_size() | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False) | |||||
| context.reset_auto_parallel_context() | |||||
| context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True) | |||||
| else: | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False) | |||||
| with open(config.DataConfig.labels_path) as label_file: | |||||
| labels = json.load(label_file) | |||||
| ds_train = create_dataset(audio_conf=config.DataConfig.SpectConfig, | |||||
| manifest_filepath=config.DataConfig.train_manifest, | |||||
| labels=labels, normalize=True, train_mode=True, | |||||
| batch_size=config.DataConfig.batch_size, rank=rank_id, group_size=group_size) | |||||
| steps_size = ds_train.get_dataset_size() | |||||
| lr = get_lr(lr_init=config.OptimConfig.learning_rate, total_epochs=config.TrainingConfig.epochs, | |||||
| steps_per_epoch=steps_size) | |||||
| lr = Tensor(lr) | |||||
| deepspeech_net = DeepSpeechModel(batch_size=config.DataConfig.batch_size, | |||||
| rnn_hidden_size=config.ModelConfig.hidden_size, | |||||
| nb_layers=config.ModelConfig.hidden_layers, | |||||
| labels=labels, | |||||
| rnn_type=config.ModelConfig.rnn_type, | |||||
| audio_conf=config.DataConfig.SpectConfig, | |||||
| bidirectional=True) | |||||
| loss_net = NetWithLossClass(deepspeech_net) | |||||
| weights = ParameterTuple(deepspeech_net.trainable_params()) | |||||
# feed the generated per-step learning-rate schedule to the optimizer
optimizer = Adam(weights, learning_rate=lr, eps=config.OptimConfig.epsilon,
                 loss_scale=config.OptimConfig.loss_scale)
| train_net = TrainOneStepCell(loss_net, optimizer) | |||||
if args.pre_trained_model_path:  # only load when a checkpoint path is given (the default is an empty string)
| param_dict = load_checkpoint(args.pre_trained_model_path) | |||||
| load_param_into_net(train_net, param_dict) | |||||
print('Successfully loaded the pre-trained model')
| model = Model(train_net) | |||||
| lr_cb = Monitor(lr) | |||||
| callback_list = [lr_cb] | |||||
| if args.is_distributed: | |||||
| config.CheckpointConfig.ckpt_file_name_prefix = config.CheckpointConfig.ckpt_file_name_prefix + str(get_rank()) | |||||
| config.CheckpointConfig.ckpt_path = os.path.join(config.CheckpointConfig.ckpt_path, | |||||
| 'ckpt_' + str(get_rank()) + '/') | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=1, | |||||
| keep_checkpoint_max=config.CheckpointConfig.keep_checkpoint_max) | |||||
| ckpt_cb = ModelCheckpoint(prefix=config.CheckpointConfig.ckpt_file_name_prefix, | |||||
| directory=config.CheckpointConfig.ckpt_path, config=config_ck) | |||||
| callback_list.append(ckpt_cb) | |||||
| model.train(config.TrainingConfig.epochs, ds_train, callbacks=callback_list) | |||||