@@ -0,0 +1,203 @@
# Contents

- [Music Auto Tagging Description](#music-auto-tagging-description)
- [Model Architecture](#model-architecture)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Music Auto Tagging Description](#contents)

This repository provides a script and recipe to train the Music Auto Tagging (FCN-4) model to achieve state-of-the-art accuracy.

[Paper](https://arxiv.org/abs/1606.00298): Keunwoo Choi, George Fazekas, and Mark Sandler, "Automatic tagging using deep convolutional neural networks," in International Society of Music Information Retrieval Conference. ISMIR, 2016.

# [Model Architecture](#contents)

Music Auto Tagging is a fully convolutional network architecture, also known as FCN-4 because it stacks four convolutional blocks. Its layers consist of convolutional layers, max-pooling layers, activation layers, and a fully connected layer.
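
For reference, `train.py` and `eval.py` instantiate the network with the following layer widths and pooling windows (copied verbatim from those scripts):

```python
from src.musictagger import MusicTaggerCNN

# Four conv/BN/ReLU/pooling blocks followed by a 50-way sigmoid Dense layer
network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                         kernel_size=[3, 3, 3, 3, 3],
                         padding=[0] * 5,
                         maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                         has_bias=True)
```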
# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both single-precision and half-precision data formats while maintaining the accuracy achieved by single-precision training. Mixed precision training accelerates computation, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and searching for "reduce precision".
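
In this repository, mixed precision is controlled by the `mixed_precision` flag in `src/config.py`; a minimal sketch of how `train.py` applies it:

```python
from mindspore import context
from src.config import music_cfg as cfg

# train.py toggles auto mixed precision from the config flag
context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)
```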
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

### 1. Download and preprocess the dataset

1. Download a music tagging dataset (for instance, the MagnaTagATune dataset or the Million Song Dataset).
2. Extract the dataset.
3. The information file of each clip should contain the label and path. Please refer to `annotations_final.csv` in the MagnaTagATune dataset.
4. The provided pre-processing script uses the MagnaTagATune dataset as an example. Please modify the code according to your own needs. A sketch of the expected annotation format follows this list.
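
For reference, `GetLabel` in `src/pre_process_data.py` parses this file as tab-separated with every field wrapped in quotes; a minimal sketch of that parsing (the sample row below is hypothetical):

```python
# annotations_final.csv is tab-separated and every field is quoted, so the
# pre-processing strips the first and last character of each field.
row = '"2"\t"0"\t"1"\t"f/some_clip.mp3"'  # hypothetical: clip_id, tag flags, mp3 path
fields = [f[1:-1] for f in row.strip().split("\t")]
print(fields)  # ['2', '0', '1', 'f/some_clip.mp3']
```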
### 2. Set up parameters (src/config.py)

### 3. Train

After preparing your dataset, first convert the audio clips into a MindRecord dataset:

```shell
python src/pre_process_data.py --device_id 0
```

Then, start training the model:

```shell
SLOG_PRINT_TO_STDOUT=1 python train.py --device_id 0
```

### 4. Test

Then you can evaluate your model:

```shell
SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```
├── model_zoo
    ├── README.md                      // descriptions about all the models
    ├── music_auto_tagging
        ├── README.md                  // descriptions about Music Auto Tagging
        ├── scripts
        │   ├── run_train.sh           // shell script for training on Ascend
        │   ├── run_eval.sh            // shell script for evaluation on Ascend
        │   ├── run_process_data.sh    // shell script for converting audio clips to mindrecord
        ├── src
        │   ├── dataset.py             // creating dataset
        │   ├── pre_process_data.py    // pre-processing dataset
        │   ├── musictagger.py         // Music Auto Tagging (FCN-4) architecture
        │   ├── config.py              // parameter configuration
        │   ├── loss.py                // loss function
        │   ├── tag.txt                // tag for each number
        ├── train.py                   // training script
        ├── eval.py                    // evaluation script
        ├── export.py                  // export model in AIR format
```
## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for Music Auto Tagging

```python
'num_classes': 50                                  # number of tagging classes
'num_consumer': 4                                  # file number for mindrecord
'get_npy': 1                                       # mode for converting audio clips to npy, default 1 in this case
'get_mindrecord': 1                                # mode for converting npy files into mindrecord files, default 1 in this case
'audio_path': "/dev/data/Music_Tagger_Data/fea/"   # path to audio clips
'npy_path': "/dev/data/Music_Tagger_Data/fea/"     # path to numpy files
'info_path': "/dev/data/Music_Tagger_Data/fea/"    # path to info_name, which provides the label of each audio clip
'info_name': 'annotations_final.csv'               # info_name
'device_target': 'Ascend'                          # device running the program
'device_id': 0                                     # device ID used to train or evaluate the dataset. Ignore it when you use run_train.sh for distributed training
'mr_path': '/dev/data/Music_Tagger_Data/fea/'      # path to mindrecord
'mr_name': ['train', 'val']                        # mindrecord name
'pre_trained': False                               # whether training based on the pre-trained model
'lr': 0.0005                                       # learning rate
'batch_size': 32                                   # training batch size
'epoch_size': 10                                   # total training epochs
'loss_scale': 1024.0                               # loss scale
'num_consumer': 4                                  # file number for mindrecord
'mixed_precision': False                           # whether to use mixed precision
'train_filename': 'train.mindrecord0'              # file name of the train mindrecord data
'val_filename': 'val.mindrecord0'                  # file name of the evaluation mindrecord data
'data_dir': '/dev/data/Music_Tagger_Data/fea/'     # directory of mindrecord data
'device_target': 'Ascend'                          # device running the program
'device_id': 0                                     # device ID used to train or evaluate the dataset. Ignore it when you use run_train.sh for distributed training
'keep_checkpoint_max': 10                          # only keep the last keep_checkpoint_max checkpoints
'save_step': 2000                                  # steps for saving checkpoint
'checkpoint_path': '/dev/data/Music_Tagger_Data/model/'  # the absolute full path to save the checkpoint file
'prefix': 'MusicTagger'                            # prefix of checkpoint
'model_name': 'MusicTagger_3-50_543.ckpt'          # checkpoint name
```
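
These fields live in two `EasyDict` objects in `src/config.py` (`data_cfg` for pre-processing, `music_cfg` for training and evaluation) and are read as plain attributes:

```python
from src.config import data_cfg, music_cfg

# EasyDict exposes the entries above as attributes
print(music_cfg.lr, music_cfg.batch_size, music_cfg.epoch_size)  # 0.0005 32 10
print(data_cfg.num_classes, data_cfg.mr_name)                    # 50 ['train', 'val']
```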
## [Training Process](#contents)

### Training

- running on Ascend

```
python train.py > train.log 2>&1 &
```

The python command above will run in the background; you can view the results through the file `train.log`.

After training, you'll get some checkpoint files under the script folder by default. The loss values will look as follows:

```
# grep "loss is " train.log
epoch: 1 step: 100, loss is 0.23264095
epoch: 1 step: 200, loss is 0.2013525
...
```

The model checkpoint will be saved in the configured directory.
## [Evaluation Process](#contents)

### Evaluation
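
- running on Ascend

Evaluation uses the validation MindRecord file and the checkpoint configured in `src/config.py` (`val_filename`, `checkpoint_path`, `model_name`):

```
python eval.py > eval.log 2>&1 &
```

The command above runs in the background; `eval.py` loads the configured checkpoint, evaluates it on the validation data, and prints the validation AUC (about 0.909 for the reference checkpoint, see the performance table below), which you can view through the file `eval.log`.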
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | Ascend                                                      |
| -------------------------- | ----------------------------------------------------------- |
| Model Version              | FCN-4                                                       |
| Resource                   | Ascend 910; CPU 2.60 GHz, 56 cores; memory 314 GB           |
| Uploaded Date              | 09/11/2020 (month/day/year)                                 |
| MindSpore Version          | r0.7.0                                                      |
| Training Parameters        | epoch=10, steps=534, batch_size=32, lr=0.0005               |
| Optimizer                  | Adam                                                        |
| Loss Function              | Binary cross entropy                                        |
| Outputs                    | probability                                                 |
| Accuracy                   | AUC 0.909                                                   |
| Speed                      | 1pc: 160 samples/sec                                        |
| Total Time                 | 1pc: 20 mins                                                |
| Checkpoint for Fine Tuning | 198.73 MB (.ckpt file)                                      |
| Scripts                    | [music_auto_tagging script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/audio/music_auto_tagging) |

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -0,0 +1,137 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############evaluate trained models#################
python eval.py
'''
import argparse

import numpy as np

import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.musictagger import MusicTaggerCNN
from src.config import music_cfg as cfg
from src.dataset import create_dataset

def calculate_auc(labels_list, preds_list):
    """
    The AUC calculation function

    Input:
        labels_list: list of true labels
        preds_list: list of predicted labels
    Outputs
        Float, mean of AUC
    """
    auc = []
    n_bins = labels_list.shape[0] // 2
    if labels_list.ndim == 1:
        labels_list = labels_list.reshape(-1, 1)
        preds_list = preds_list.reshape(-1, 1)
    for i in range(labels_list.shape[1]):
        labels = labels_list[:, i]
        preds = preds_list[:, i]
        positive_len = labels.sum()
        negative_len = labels.shape[0] - positive_len
        total_case = positive_len * negative_len
        positive_histogram = np.zeros((n_bins))
        negative_histogram = np.zeros((n_bins))
        bin_width = 1.0 / n_bins
        for j, _ in enumerate(labels):
            # clamp to the last bin so a prediction of exactly 1.0 stays in range
            nth_bin = min(int(preds[j] // bin_width), n_bins - 1)
            if labels[j]:
                positive_histogram[nth_bin] = positive_histogram[nth_bin] + 1
            else:
                negative_histogram[nth_bin] = negative_histogram[nth_bin] + 1
        accumulated_negative = 0
        satisfied_pair = 0
        for k in range(n_bins):
            satisfied_pair += (
                positive_histogram[k] * accumulated_negative +
                positive_histogram[k] * negative_histogram[k] * 0.5)
            accumulated_negative += negative_histogram[k]
        auc.append(satisfied_pair / total_case)

    return np.mean(auc)
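
# Example with hypothetical values: perfectly separated predictions yield AUC 1.0,
# e.g. calculate_auc(np.array([1, 0, 1, 0]), np.array([0.9, 0.1, 0.8, 0.2])) -> 1.0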

def val(net, data_dir, filename, num_consumer=4, batch=32):
    """
    Validation function, estimate the performance of trained model

    Input:
        net: the trained neural network
        data_dir: path to the validation dataset
        filename: name of the validation dataset
        num_consumer: split number of validation dataset
        batch: validation batch size
    Outputs
        Float, AUC
    """
    data_val = create_dataset(data_dir, filename, batch, ['feature', 'label'],
                              num_consumer)
    data_val = data_val.create_tuple_iterator()
    res_pred = []
    res_true = []
    for data, label in data_val:
        x = net(Tensor(data, dtype=mstype.float32))
        res_pred.append(x.asnumpy())
        res_true.append(label.asnumpy())
    res_pred = np.concatenate(res_pred, axis=0)
    res_true = np.concatenate(res_true, axis=0)
    auc = calculate_auc(res_true, res_pred)
    return auc


def validation(net, model_path, data_dir, filename, num_consumer, batch):
    """Load a checkpoint into the network and compute its validation AUC"""
    param_dict = load_checkpoint(model_path)
    load_param_into_net(net, param_dict)
    auc = val(net, data_dir, filename, num_consumer, batch)
    return auc


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate model')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)
    args = parser.parse_args()
    if args.device_id is not None:
        context.set_context(device_target=cfg.device_target,
                            mode=context.GRAPH_MODE,
                            device_id=args.device_id)
    else:
        context.set_context(device_target=cfg.device_target,
                            mode=context.GRAPH_MODE,
                            device_id=cfg.device_id)
    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    network.set_train(False)
    auc_val = validation(network, cfg.checkpoint_path + "/" + cfg.model_name,
                         cfg.data_dir, cfg.val_filename, cfg.num_consumer,
                         cfg.batch_size)

    print("=" * 10 + "Validation Performance" + "=" * 10)
    print("AUC: {:.5f}".format(auc_val))
@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############export trained model#################
python export.py
'''
import numpy as np

from mindspore import Tensor
from mindspore.train.serialization import export
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.musictagger import MusicTaggerCNN
from src.config import music_cfg as cfg

if __name__ == "__main__":
    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    param_dict = load_checkpoint(cfg.checkpoint_path + "/" + cfg.model_name)
    load_param_into_net(network, param_dict)
    input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32)
    export(network,
           Tensor(input_data),
           filename="{}/{}.air".format(cfg.checkpoint_path,
                                       cfg.model_name[:-5]),
           file_format="AIR")
@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export SLOG_PRINT_TO_STDOUT=1
python ../eval.py --device_id 0
@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export SLOG_PRINT_TO_STDOUT=1
python ../src/pre_process_data.py --device_id 0
@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export SLOG_PRINT_TO_STDOUT=1
python ../train.py --device_id 0
@@ -0,0 +1,23 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
__init__.py
"""
from . import musictagger
from . import loss
from . import dataset
from . import config
from . import pre_process_data
@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py, eval.py
"""
from easydict import EasyDict as edict

data_cfg = edict({
    'num_classes': 50,
    'num_consumer': 4,
    'get_npy': 1,
    'get_mindrecord': 1,
    'audio_path': "/dev/data/Music_Tagger_Data/fea/",
    'npy_path': "/dev/data/Music_Tagger_Data/fea/",
    'info_path': "/dev/data/Music_Tagger_Data/fea/",
    'info_name': 'annotations_final.csv',
    'device_target': 'Ascend',
    'device_id': 0,
    'mr_path': '/dev/data/Music_Tagger_Data/fea/',
    'mr_name': ['train', 'val'],
})

music_cfg = edict({
    'pre_trained': False,
    'lr': 0.0005,
    'batch_size': 32,
    'epoch_size': 10,
    'loss_scale': 1024.0,
    'num_consumer': 4,
    'mixed_precision': False,
    'train_filename': 'train.mindrecord0',
    'val_filename': 'val.mindrecord0',
    'data_dir': '/dev/data/Music_Tagger_Data/fea/',
    'device_target': 'Ascend',
    'device_id': 0,
    'keep_checkpoint_max': 10,
    'save_step': 2000,
    'checkpoint_path': '/dev/data/Music_Tagger_Data/model',
    'prefix': 'MusicTagger',
    'model_name': 'MusicTagger_3-50_543.ckpt',
})
@@ -0,0 +1,30 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''python dataset.py'''
import os

import mindspore.dataset as ds


def create_dataset(base_path, filename, batch_size, columns_list,
                   num_consumer):
    """Create dataset"""
    path = os.path.join(base_path, filename)
    dtrain = ds.MindDataset(path, columns_list, num_consumer)
    dtrain = dtrain.shuffle(buffer_size=dtrain.get_dataset_size())
    dtrain = dtrain.batch(batch_size, drop_remainder=True)
    return dtrain
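
# Example usage, mirroring the call in train.py:
#   data_train = create_dataset(cfg.data_dir, cfg.train_filename,
#                               cfg.batch_size, ['feature', 'label'],
#                               cfg.num_consumer)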
@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
define loss
"""
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P


class BCELoss(nn.Cell):
    """
    BCELoss
    """
    def __init__(self, record=None):
        super(BCELoss, self).__init__(record)
        self.sm_scalar = P.ScalarSummary()
        self.cast = P.Cast()
        self.record = record
        self.weight = None
        self.bce = P.BinaryCrossEntropy()

    def construct(self, input_data, target):
        target = self.cast(target, mstype.float32)
        loss = self.bce(input_data, target, self.weight)
        if self.record:
            self.sm_scalar("loss", loss)
        return loss
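
# Note: P.BinaryCrossEntropy expects probabilities in [0, 1]. MusicTaggerCNN's
# final Dense layer already applies a sigmoid activation, so its outputs can be
# fed to this loss directly, as train.py does via Model(network, BCELoss(), ...).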
@@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''model'''
from mindspore import nn
from mindspore.ops import operations as P


class MusicTaggerCNN(nn.Cell):
    """
    Music Tagger CNN
    """
    def __init__(self, in_classes, kernel_size, padding, maxpool, has_bias):
        super(MusicTaggerCNN, self).__init__()
        self.in_classes = in_classes
        self.kernel_size = kernel_size
        self.maxpool = maxpool
        self.padding = padding
        self.has_bias = has_bias
        # build model
        self.conv1 = nn.Conv2d(self.in_classes[0], self.in_classes[1],
                               self.kernel_size[0])
        self.conv2 = nn.Conv2d(self.in_classes[1], self.in_classes[2],
                               self.kernel_size[1])
        self.conv3 = nn.Conv2d(self.in_classes[2], self.in_classes[3],
                               self.kernel_size[2])
        self.conv4 = nn.Conv2d(self.in_classes[3], self.in_classes[4],
                               self.kernel_size[3])
        self.bn1 = nn.BatchNorm2d(self.in_classes[1])
        self.bn2 = nn.BatchNorm2d(self.in_classes[2])
        self.bn3 = nn.BatchNorm2d(self.in_classes[3])
        self.bn4 = nn.BatchNorm2d(self.in_classes[4])
        self.pool1 = nn.MaxPool2d(maxpool[0], maxpool[0])
        self.pool2 = nn.MaxPool2d(maxpool[1], maxpool[1])
        self.pool3 = nn.MaxPool2d(maxpool[2], maxpool[2])
        self.pool4 = nn.MaxPool2d(maxpool[3], maxpool[3])
        self.poolreduce = P.ReduceMax(keep_dims=False)
        self.Act = nn.ReLU()
        self.flatten = nn.Flatten()
        self.dense = nn.Dense(2048, 50, activation='sigmoid')
        self.sigmoid = nn.Sigmoid()

    def construct(self, input_data):
        """
        Music Tagger CNN
        """
        x = self.conv1(input_data)
        x = self.bn1(x)
        x = self.Act(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.Act(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.Act(x)
        x = self.pool3(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.Act(x)
        x = self.poolreduce(x, (2, 3))
        x = self.flatten(x)
        x = self.dense(x)
        return x
@@ -0,0 +1,226 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''pre-process audio clips into mindrecord'''
import os
import argparse

import pandas as pd
import numpy as np
import librosa

from mindspore.mindrecord import FileWriter
from mindspore import context

from src.config import data_cfg as cfg


def compute_melgram(audio_path, save_path='', filename='', save_npy=True):
    """
    extract melgram feature from the audio and save as numpy array

    Args:
        audio_path (str): path to the audio clip.
        save_path (str): path to save the numpy array.
        filename (str): filename of the audio clip.

    Returns:
        numpy array.
    """
    SR = 12000
    N_FFT = 512
    N_MELS = 96
    HOP_LEN = 256
    DURA = 29.12  # 29.12 s at 12 kHz with hop 256 gives 1 + 349440 // 256 = 1366 frames

    src, _ = librosa.load(audio_path, sr=SR)  # whole signal
    n_sample = src.shape[0]
    n_sample_fit = int(DURA * SR)

    if n_sample < n_sample_fit:  # if too short, zero-pad to the target length
        src = np.hstack((src, np.zeros((int(DURA * SR) - n_sample,))))
    elif n_sample > n_sample_fit:  # if too long, keep the centered excerpt
        src = src[(n_sample - n_sample_fit) // 2:(n_sample + n_sample_fit) //
                  2]
    logam = librosa.core.amplitude_to_db
    melgram = librosa.feature.melspectrogram
    ret = logam(
        melgram(y=src, sr=SR, hop_length=HOP_LEN, n_fft=N_FFT, n_mels=N_MELS))
    ret = ret[np.newaxis, np.newaxis, :]
    if save_npy:
        save_path = save_path + filename[:-4] + '.npy'
        np.save(save_path, ret)
    return ret
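
# Example usage (hypothetical paths), as driven by the __main__ block below:
#   compute_melgram("/data/mp3/f/clip.mp3", "/data/npy/f/", "clip.mp3")
# saves /data/npy/f/clip.npy with shape (1, 1, 96, 1366).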

def get_data(features_data, labels_data):
    """pack features and labels into the dict layout expected by FileWriter"""
    data_list = []
    for i, (label, feature) in enumerate(zip(labels_data, features_data)):
        data_json = {"id": i, "feature": feature, "label": label}
        data_list.append(data_json)
    return data_list


def convert(s):
    if s.isdigit():
        return int(s)
    return s


def GetLabel(info_path, info_name):
    """
    separate dataset into training set and validation set

    Args:
        info_path (str): path to the information file.
        info_name (str): name of the information file.
    """
    T = []
    with open(info_path + '/' + info_name, 'rb') as info:
        data = info.readline()
        while data:
            T.append([
                convert(i[1:-1])
                for i in data.strip().decode('utf-8').split("\t")
            ])
            data = info.readline()
    annotation = pd.DataFrame(T[1:], columns=T[0])

    # keep the 50 most frequent tags
    count = []
    for i in annotation.columns[1:-2]:
        count.append([annotation[i].sum() / len(annotation), i])
    count = sorted(count)
    full_label = []
    for i in count[-50:]:
        full_label.append(i[1])

    out = []
    for i in T[1:]:
        index = [k for k, x in enumerate(i) if x == 1]
        label = [T[0][k] for k in index]
        L = [str(0) for k in range(50)]
        L.append(i[-1])
        for j in label:
            if j in full_label:
                ind = full_label.index(j)
                L[ind] = '1'
        out.append(L)
    out = np.array(out)

    # random 80/20 split into training and validation sets
    Train = []
    Val = []
    for i in out:
        if np.random.rand() > 0.2:
            Train.append(i)
        else:
            Val.append(i)
    np.savetxt("{}/music_tagging_train_tmp.csv".format(info_path),
               np.array(Train),
               fmt='%s',
               delimiter=',')
    np.savetxt("{}/music_tagging_val_tmp.csv".format(info_path),
               np.array(Val),
               fmt='%s',
               delimiter=',')


def generator_md(info_name, file_path, num_classes):
    """
    generate numpy arrays from features of all audio clips

    Args:
        info_name (str): path to the information file.
        file_path (str): path to the npy files.

    Returns:
        2 numpy arrays (features and labels).
    """
    df = pd.read_csv(info_name, header=None)
    df.columns = [str(i) for i in range(num_classes)] + ["mp3_path"]
    data = []
    label = []
    for i in range(len(df)):
        try:
            data.append(
                np.load(file_path + df.mp3_path.values[i][:-4] +
                        '.npy').reshape(1, 96, 1366))
            label.append(np.array(df[df.columns[:-1]][i:i + 1])[0])
        except FileNotFoundError:
            print("npy file not found for {}".format(df.mp3_path.values[i]))
    return np.array(data), np.array(label, dtype=np.int32)


def convert_to_mindrecord(info_name, file_path, store_path, mr_name,
                          num_classes):
    """ convert dataset to mindrecord """
    num_shard = 4
    data, label = generator_md(info_name, file_path, num_classes)
    schema_json = {
        "id": {
            "type": "int32"
        },
        "feature": {
            "type": "float32",
            "shape": [1, 96, 1366]
        },
        "label": {
            "type": "int32",
            "shape": [num_classes]
        }
    }

    writer = FileWriter(
        os.path.join(store_path, '{}.mindrecord'.format(mr_name)), num_shard)
    datax = get_data(data, label)
    writer.add_schema(schema_json, "music_tagger_schema")
    writer.add_index(["id"])
    writer.write_raw_data(datax)
    writer.commit()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='get feature')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)
    args = parser.parse_args()
    if cfg.get_npy:
        GetLabel(cfg.info_path, cfg.info_name)
        dirname = os.listdir(cfg.audio_path)
        for d in dirname:
            file_name = os.listdir("{}/{}".format(cfg.audio_path, d))
            if not os.path.isdir("{}/{}".format(cfg.npy_path, d)):
                os.mkdir("{}/{}".format(cfg.npy_path, d))
            for f in file_name:
                compute_melgram("{}/{}/{}".format(cfg.audio_path, d, f),
                                "{}/{}/".format(cfg.npy_path, d), f)

    if cfg.get_mindrecord:
        if args.device_id is not None:
            context.set_context(device_target='Ascend',
                                mode=context.GRAPH_MODE,
                                device_id=args.device_id)
        else:
            context.set_context(device_target='Ascend',
                                mode=context.GRAPH_MODE,
                                device_id=cfg.device_id)
        for cmn in cfg.mr_name:
            if cmn in ['train', 'val']:
                # GetLabel writes the temporary csv files under cfg.info_path
                convert_to_mindrecord(
                    '{}/music_tagging_{}_tmp.csv'.format(cfg.info_path, cmn),
                    cfg.npy_path, cfg.mr_path, cmn, cfg.num_classes)
@@ -0,0 +1,50 @@
choral
female voice
metal
country
weird
no voice
cello
harp
beats
female vocal
male voice
dance
new age
voice
choir
classic
man
solo
sitar
soft
no vocal
pop
male vocal
woman
flute
quiet
loud
harpsichord
no vocals
vocals
singing
male
opera
indian
female
synth
vocal
violin
beat
ambient
piano
fast
rock
electronic
drums
strings
techno
slow
classical
guitar
@@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############train models#################
python train.py
'''
import argparse

from mindspore import context, nn
from mindspore.train import Model
from mindspore.common import set_seed
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from src.dataset import create_dataset
from src.musictagger import MusicTaggerCNN
from src.loss import BCELoss
from src.config import music_cfg as cfg


def train(model, dataset_direct, filename, columns_list, num_consumer=4,
          batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50,
          prefix="model", directory='./'):
    """
    train network
    """
    config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                 keep_checkpoint_max=keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=prefix,
                                 directory=directory,
                                 config=config_ck)
    data_train = create_dataset(dataset_direct, filename, batch, columns_list,
                                num_consumer)
    model.train(epoch,
                data_train,
                callbacks=[
                    ckpoint_cb,
                    LossMonitor(per_print_times=181),
                    TimeMonitor()
                ],
                dataset_sink_mode=True)


if __name__ == "__main__":
    set_seed(1)
    parser = argparse.ArgumentParser(description='Train model')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)
    args = parser.parse_args()
    if args.device_id is not None:
        context.set_context(device_target='Ascend',
                            mode=context.GRAPH_MODE,
                            device_id=args.device_id)
    else:
        context.set_context(device_target='Ascend',
                            mode=context.GRAPH_MODE,
                            device_id=cfg.device_id)
    context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)

    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    if cfg.pre_trained:
        param_dict = load_checkpoint(cfg.checkpoint_path + '/' +
                                     cfg.model_name)
        load_param_into_net(network, param_dict)

    net_loss = BCELoss()
    network.set_train(True)
    net_opt = nn.Adam(params=network.trainable_params(),
                      learning_rate=cfg.lr,
                      loss_scale=cfg.loss_scale)
    loss_scale_manager = FixedLossScaleManager(loss_scale=cfg.loss_scale,
                                               drop_overflow_update=False)
    net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager)

    train(model=net_model,
          dataset_direct=cfg.data_dir,
          filename=cfg.train_filename,
          columns_list=['feature', 'label'],
          num_consumer=cfg.num_consumer,
          batch=cfg.batch_size,
          epoch=cfg.epoch_size,
          save_checkpoint_steps=cfg.save_step,
          keep_checkpoint_max=cfg.keep_checkpoint_max,
          prefix=cfg.prefix,
          directory=cfg.checkpoint_path + "_{}".format(cfg.device_id))
    print("train success")