| @@ -0,0 +1,203 @@ | |||
# Contents

- [Music Auto Tagging Description](#music-auto-tagging-description)
- [Model Architecture](#model-architecture)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Music Auto Tagging Description](#contents)

This repository provides a script and recipe to train the Music Auto Tagging model to achieve state-of-the-art accuracy.

[Paper](https://arxiv.org/abs/1606.00298): Keunwoo Choi, George Fazekas, and Mark Sandler, "Automatic tagging using deep convolutional neural networks," in International Society of Music Information Retrieval Conference. ISMIR, 2016.

# [Model Architecture](#contents)

Music Auto Tagging (FCN-4) is a fully convolutional neural network; the "4" in its name refers to its four convolutional layers. The network is built from convolutional layers, max-pooling layers, activation layers, and a fully connected layer.
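The pooling schedule in src/musictagger.py shrinks a (96, 1366) log-mel spectrogram down to one value per channel before the fully connected layer; below is a minimal sketch of that arithmetic (note that `construct()` replaces the fourth pooling layer with a global `ReduceMax`):

```python
# Trace the feature-map size through the pooling stages of FCN-4 as
# configured in src/musictagger.py (input: 1 x 96 x 1366 spectrogram).
shape = (96, 1366)
for pool_h, pool_w in [(2, 4), (4, 5), (3, 8)]:
    shape = (shape[0] // pool_h, shape[1] // pool_w)
    print(shape)  # (48, 341) -> (12, 68) -> (4, 8)
# A global ReduceMax over the remaining 4 x 8 map leaves one value per
# channel, i.e. a 2048-dimensional feature vector for the Dense(2048, 50).
```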
# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and searching for 'reduce precision'.
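In this repository, mixed precision is toggled through the `mixed_precision` flag in src/config.py, which train.py passes to the MindSpore context; a minimal sketch:

```python
# Hedged sketch mirroring train.py: enable auto mixed precision on Ascend.
from mindspore import context

context.set_context(device_target='Ascend', mode=context.GRAPH_MODE)
context.set_context(enable_auto_mixed_precision=True)  # cfg.mixed_precision in train.py
```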
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

### 1. Download and preprocess the dataset

1. Download a classification dataset (for instance, the MagnaTagATune dataset, the Million Song Dataset, etc.)
2. Extract the dataset.
3. The information file of each clip should contain the label and the path. Please refer to annotations_final.csv in the MagnaTagATune dataset, and see the sketch after this list.
4. The provided pre-processing script uses the MagnaTagATune dataset as an example. Please modify the code according to your own needs.
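The annotation file is read as tab-separated text by src/pre_process_data.py: one column per tag plus a final `mp3_path` column. A hedged peek at the expected format (pandas is used here only for illustration; the column names shown are examples from MagnaTagATune):

```python
# Inspect the annotation file layout expected by the pre-processing script.
import pandas as pd

df = pd.read_csv("annotations_final.csv", sep="\t")
print(df.columns[0], df.columns[-1])  # e.g. clip_id ... mp3_path
print(df["mp3_path"].iloc[0])         # relative path to one audio clip
```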
### 2. Set up the parameters (src/config.py)

### 3. Train

After preparing your dataset, first convert the audio clips into a MindRecord dataset with the following command:

```shell
python src/pre_process_data.py --device_id 0
```

Then, you can start training the model with the following command:

```shell
SLOG_PRINT_TO_STDOUT=1 python train.py --device_id 0
```

### 4. Test

Then you can evaluate your model:

```shell
SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```
├── model_zoo
    ├── README.md                    // descriptions about all the models
    ├── music_auto_tagging
        ├── README.md                // descriptions about Music Auto Tagging
        ├── scripts
        │   ├── run_train.sh         // shell script for training on Ascend
        │   ├── run_eval.sh          // shell script for evaluation on Ascend
        │   ├── run_process_data.sh  // shell script for converting audio clips to MindRecord
        ├── src
        │   ├── dataset.py           // creating dataset
        │   ├── pre_process_data.py  // pre-processing dataset
        │   ├── musictagger.py       // Music Auto Tagging (FCN-4) architecture
        │   ├── config.py            // parameter configuration
        │   ├── loss.py              // loss function
        │   ├── tag.txt              // tag name for each label index
        ├── train.py                 // training script
        ├── eval.py                  // evaluation script
        ├── export.py                // export model in AIR format
```
## [Script Parameters](#contents)

Parameters for data pre-processing, training, and evaluation can be set in config.py.

- config for data pre-processing (`data_cfg`)

```python
'num_classes': 50,                   # number of tagging classes
'num_consumer': 4,                   # file number for mindrecord
'get_npy': 1,                        # mode for converting audio clips to npy, default 1 in this case
'get_mindrecord': 1,                 # mode for converting npy files into mindrecord files, default 1 in this case
'audio_path': "/dev/data/Music_Tagger_Data/fea/",   # path to audio clips
'npy_path': "/dev/data/Music_Tagger_Data/fea/",     # path to numpy files
'info_path': "/dev/data/Music_Tagger_Data/fea/",    # path to the info file, which provides the label of each audio clip
'info_name': 'annotations_final.csv',               # name of the info file
'device_target': 'Ascend',           # device running the program
'device_id': 0,                      # device ID used to train or evaluate the dataset; ignore it when you use run_train.sh for distributed training
'mr_path': '/dev/data/Music_Tagger_Data/fea/',      # path to mindrecord
'mr_name': ['train', 'val'],         # mindrecord names
```

- config for Music Auto Tagging (`music_cfg`)

```python
'pre_trained': False,                # whether training based on a pre-trained model
'lr': 0.0005,                        # learning rate
'batch_size': 32,                    # training batch size
'epoch_size': 10,                    # total training epochs
'loss_scale': 1024.0,                # loss scale
'num_consumer': 4,                   # file number for mindrecord
'mixed_precision': False,            # whether to use mixed-precision training
'train_filename': 'train.mindrecord0',              # file name of the training mindrecord data
'val_filename': 'val.mindrecord0',                  # file name of the evaluation mindrecord data
'data_dir': '/dev/data/Music_Tagger_Data/fea/',     # directory of mindrecord data
'device_target': 'Ascend',           # device running the program
'device_id': 0,                      # device ID used to train or evaluate the dataset; ignore it when you use run_train.sh for distributed training
'keep_checkpoint_max': 10,           # only keep the last keep_checkpoint_max checkpoints
'save_step': 2000,                   # steps between checkpoint saves
'checkpoint_path': '/dev/data/Music_Tagger_Data/model/',  # absolute path to save the checkpoint files
'prefix': 'MusicTagger',             # prefix of checkpoint files
'model_name': 'MusicTagger_3-50_543.ckpt',          # checkpoint name
```
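Both dictionaries are plain `easydict` objects, so the scripts read settings as attributes; a minimal sketch (mirroring the imports at the top of train.py and eval.py):

```python
# How the training/evaluation scripts consume the config (see src/config.py).
from src.config import music_cfg as cfg

print(cfg.lr, cfg.batch_size)  # 0.0005 32
```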
## [Training Process](#contents)

### Training

- running on Ascend

```
python train.py > train.log 2>&1 &
```

The python command above will run in the background; you can view the results through the file `train.log`.

After training, you'll get some checkpoint files under the script folder by default. The loss values will be logged as follows:

```
# grep "loss is " train.log
epoch: 1 step: 100, loss is 0.23264095
epoch: 1 step: 200, loss is 0.2013525
...
```

The model checkpoint will be saved in the configured directory.
## [Evaluation Process](#contents)

### Evaluation
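- running on Ascend

```
python eval.py > eval.log 2>&1 &
```

The command above loads the checkpoint named by `model_name` in src/config.py, evaluates it on the validation mindrecord file, and prints the mean AUC over the 50 tags, for example (the value is taken from the performance table below, not a guaranteed output):

```
AUC: 0.90900
```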
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | Ascend                                                      |
| -------------------------- | ----------------------------------------------------------- |
| Model Version              | FCN-4                                                       |
| Resource                   | Ascend 910; CPU 2.60 GHz, 56 cores; Memory 314 GB           |
| Uploaded Date              | 09/11/2020 (month/day/year)                                 |
| MindSpore Version          | r0.7.0                                                      |
| Training Parameters        | epoch=10, steps=534, batch_size=32, lr=0.0005               |
| Optimizer                  | Adam                                                        |
| Loss Function              | Binary cross entropy                                        |
| Outputs                    | probability                                                 |
| Accuracy                   | AUC 0.909                                                   |
| Speed                      | 1pc: 160 samples/sec                                        |
| Total time                 | 1pc: 20 mins                                                |
| Checkpoint for Fine tuning | 198.73 MB (.ckpt file)                                      |
| Scripts                    | [music_auto_tagging script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/audio/music_auto_tagging) |

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
| @@ -0,0 +1,137 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| ''' | |||
| ##############evaluate trained models################# | |||
| python eval.py | |||
| ''' | |||
| import argparse | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import context | |||
| from mindspore import Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.musictagger import MusicTaggerCNN | |||
| from src.config import music_cfg as cfg | |||
| from src.dataset import create_dataset | |||
def calculate_auc(labels_list, preds_list):
    """
    The AUC calculation function.

    Approximates AUC by histogramming the predicted scores into n_bins
    buckets and counting correctly ranked positive/negative pairs.

    Inputs:
        labels_list: array of true labels, shape (N,) or (N, num_tags)
        preds_list: array of predicted scores in [0, 1], same shape
    Outputs:
        Float, mean AUC over all tag columns
    """
    auc = []
    n_bins = labels_list.shape[0] // 2
    if labels_list.ndim == 1:
        labels_list = labels_list.reshape(-1, 1)
        preds_list = preds_list.reshape(-1, 1)
    for i in range(labels_list.shape[1]):
        labels = labels_list[:, i]
        preds = preds_list[:, i]
        positive_len = labels.sum()
        negative_len = labels.shape[0] - positive_len
        total_case = positive_len * negative_len
        positive_histogram = np.zeros((n_bins))
        negative_histogram = np.zeros((n_bins))
        bin_width = 1.0 / n_bins
        for j, _ in enumerate(labels):
            # clamp so that a score of exactly 1.0 falls into the last bin
            nth_bin = min(int(preds[j] // bin_width), n_bins - 1)
            if labels[j]:
                positive_histogram[nth_bin] = positive_histogram[nth_bin] + 1
            else:
                negative_histogram[nth_bin] = negative_histogram[nth_bin] + 1
        accumulated_negative = 0
        satisfied_pair = 0
        for k in range(n_bins):
            satisfied_pair += (
                positive_histogram[k] * accumulated_negative +
                positive_histogram[k] * negative_histogram[k] * 0.5)
            accumulated_negative += negative_histogram[k]
        auc.append(satisfied_pair / total_case)
    return np.mean(auc)
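# Example with toy values: labels [1, 0, 1, 0] and scores [0.9, 0.2, 0.8, 0.4]
# rank every positive above every negative, so calculate_auc returns 1.0.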
def val(net, data_dir, filename, num_consumer=4, batch=32):
    """
    Validation function, estimates the performance of the trained model.

    Inputs:
        net: the trained neural network
        data_dir: path to the validation dataset
        filename: name of the validation dataset
        num_consumer: split number of the validation dataset
        batch: validation batch size
    Outputs:
        Float, AUC
    """
    # use the requested batch size instead of a hardcoded value
    data_train = create_dataset(data_dir, filename, batch,
                                ['feature', 'label'], num_consumer)
    data_train = data_train.create_tuple_iterator()
    res_pred = []
    res_true = []
    for data, label in data_train:
        x = net(Tensor(data, dtype=mstype.float32))
        res_pred.append(x.asnumpy())
        res_true.append(label.asnumpy())
    res_pred = np.concatenate(res_pred, axis=0)
    res_true = np.concatenate(res_true, axis=0)
    auc = calculate_auc(res_true, res_pred)
    return auc
| def validation(net, model_path, data_dir, filename, num_consumer, batch): | |||
| param_dict = load_checkpoint(model_path) | |||
| load_param_into_net(net, param_dict) | |||
| auc = val(net, data_dir, filename, num_consumer, batch) | |||
| return auc | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description='Evaluate model') | |||
| parser.add_argument('--device_id', | |||
| type=int, | |||
| help='device ID', | |||
| default=None) | |||
| args = parser.parse_args() | |||
| if args.device_id is not None: | |||
| context.set_context(device_target=cfg.device_target, | |||
| mode=context.GRAPH_MODE, | |||
| device_id=args.device_id) | |||
| else: | |||
| context.set_context(device_target=cfg.device_target, | |||
| mode=context.GRAPH_MODE, | |||
| device_id=cfg.device_id) | |||
| network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048], | |||
| kernel_size=[3, 3, 3, 3, 3], | |||
| padding=[0] * 5, | |||
| maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)], | |||
| has_bias=True) | |||
| network.set_train(False) | |||
| auc_val = validation(network, cfg.checkpoint_path + "/" + cfg.model_name, cfg.data_dir, | |||
| cfg.val_filename, cfg.num_consumer, cfg.batch_size) | |||
| print("=" * 10 + "Validation Peformance" + "=" * 10) | |||
| print("AUC: {:.5f}".format(auc_val)) | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
'''
##############export trained models#################
python export.py
'''
| import numpy as np | |||
| from mindspore.train.serialization import export | |||
| from mindspore import Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.musictagger import MusicTaggerCNN | |||
| from src.config import music_cfg as cfg | |||
| if __name__ == "__main__": | |||
| network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048], | |||
| kernel_size=[3, 3, 3, 3, 3], | |||
| padding=[0] * 5, | |||
| maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)], | |||
| has_bias=True) | |||
| param_dict = load_checkpoint(cfg.checkpoint_path + "/" + cfg.model_name) | |||
| load_param_into_net(network, param_dict) | |||
| input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32) | |||
| export(network, | |||
| Tensor(input_data), | |||
| filename="{}/{}.air".format(cfg.checkpoint_path, | |||
| cfg.model_name[:-5]), | |||
| file_format="AIR") | |||
| @@ -0,0 +1,18 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| export SLOG_PRINT_TO_STDOUT=1 | |||
| python ../eval.py --device_id 0 | |||
| @@ -0,0 +1,18 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| export SLOG_PRINT_TO_STDOUT=1 | |||
| python ../src/pre_process_data.py --device_id 0 | |||
| @@ -0,0 +1,18 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| export SLOG_PRINT_TO_STDOUT=1 | |||
| python ../train.py --device_id 0 | |||
| @@ -0,0 +1,23 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| __init__.py | |||
| """ | |||
| from . import musictagger | |||
| from . import loss | |||
| from . import dataset | |||
| from . import config | |||
| from . import pre_process_data | |||
| @@ -0,0 +1,53 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| network config setting, will be used in train.py, eval.py | |||
| """ | |||
| from easydict import EasyDict as edict | |||
| data_cfg = edict({ | |||
| 'num_classes': 50, | |||
| 'num_consumer': 4, | |||
| 'get_npy': 1, | |||
| 'get_mindrecord': 1, | |||
| 'audio_path': "/dev/data/Music_Tagger_Data/fea/", | |||
| 'npy_path': "/dev/data/Music_Tagger_Data/fea/", | |||
| 'info_path': "/dev/data/Music_Tagger_Data/fea/", | |||
| 'info_name': 'annotations_final.csv', | |||
| 'device_target': 'Ascend', | |||
| 'device_id': 0, | |||
| 'mr_path': '/dev/data/Music_Tagger_Data/fea/', | |||
| 'mr_name': ['train', 'val'], | |||
| }) | |||
| music_cfg = edict({ | |||
| 'pre_trained': False, | |||
| 'lr': 0.0005, | |||
| 'batch_size': 32, | |||
| 'epoch_size': 10, | |||
| 'loss_scale': 1024.0, | |||
| 'num_consumer': 4, | |||
| 'mixed_precision': False, | |||
| 'train_filename': 'train.mindrecord0', | |||
| 'val_filename': 'val.mindrecord0', | |||
| 'data_dir': '/dev/data/Music_Tagger_Data/fea/', | |||
| 'device_target': 'Ascend', | |||
| 'device_id': 0, | |||
| 'keep_checkpoint_max': 10, | |||
| 'save_step': 2000, | |||
| 'checkpoint_path': '/dev/data/Music_Tagger_Data/model', | |||
| 'prefix': 'MusicTagger', | |||
| 'model_name': 'MusicTagger_3-50_543.ckpt', | |||
| }) | |||
| @@ -0,0 +1,30 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| '''python dataset.py''' | |||
| import os | |||
| import mindspore.dataset as ds | |||
| def create_dataset(base_path, filename, batch_size, columns_list, | |||
| num_consumer): | |||
| """Create dataset""" | |||
| path = os.path.join(base_path, filename) | |||
| dtrain = ds.MindDataset(path, columns_list, num_consumer) | |||
| dtrain = dtrain.shuffle(buffer_size=dtrain.get_dataset_size()) | |||
| dtrain = dtrain.batch(batch_size, drop_remainder=True) | |||
| return dtrain | |||
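# Example usage (paths are the defaults from src/config.py):
# create_dataset('/dev/data/Music_Tagger_Data/fea/', 'train.mindrecord0', 32,
#                ['feature', 'label'], 4) yields shuffled batches of
# (feature, label) pairs, dropping any incomplete final batch.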
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| define loss | |||
| """ | |||
| from mindspore import nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| class BCELoss(nn.Cell): | |||
| """ | |||
| BCELoss | |||
| """ | |||
| def __init__(self, record=None): | |||
| super(BCELoss, self).__init__(record) | |||
| self.sm_scalar = P.ScalarSummary() | |||
| self.cast = P.Cast() | |||
| self.record = record | |||
| self.weight = None | |||
| self.bce = P.BinaryCrossEntropy() | |||
| def construct(self, input_data, target): | |||
| target = self.cast(target, mstype.float32) | |||
| loss = self.bce(input_data, target, self.weight) | |||
| if self.record: | |||
| self.sm_scalar("loss", loss) | |||
| return loss | |||
| @@ -0,0 +1,83 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| '''model''' | |||
| from mindspore import nn | |||
| from mindspore.ops import operations as P | |||
| class MusicTaggerCNN(nn.Cell): | |||
| """ | |||
| Music Tagger CNN | |||
| """ | |||
| def __init__(self, in_classes, kernel_size, padding, maxpool, has_bias): | |||
| super(MusicTaggerCNN, self).__init__() | |||
| self.in_classes = in_classes | |||
| self.kernel_size = kernel_size | |||
| self.maxpool = maxpool | |||
| self.padding = padding | |||
| self.has_bias = has_bias | |||
| # build model | |||
| self.conv1 = nn.Conv2d(self.in_classes[0], self.in_classes[1], | |||
| self.kernel_size[0]) | |||
| self.conv2 = nn.Conv2d(self.in_classes[1], self.in_classes[2], | |||
| self.kernel_size[1]) | |||
| self.conv3 = nn.Conv2d(self.in_classes[2], self.in_classes[3], | |||
| self.kernel_size[2]) | |||
| self.conv4 = nn.Conv2d(self.in_classes[3], self.in_classes[4], | |||
| self.kernel_size[3]) | |||
| self.bn1 = nn.BatchNorm2d(self.in_classes[1]) | |||
| self.bn2 = nn.BatchNorm2d(self.in_classes[2]) | |||
| self.bn3 = nn.BatchNorm2d(self.in_classes[3]) | |||
| self.bn4 = nn.BatchNorm2d(self.in_classes[4]) | |||
| self.pool1 = nn.MaxPool2d(maxpool[0], maxpool[0]) | |||
| self.pool2 = nn.MaxPool2d(maxpool[1], maxpool[1]) | |||
| self.pool3 = nn.MaxPool2d(maxpool[2], maxpool[2]) | |||
| self.pool4 = nn.MaxPool2d(maxpool[3], maxpool[3]) | |||
| self.poolreduce = P.ReduceMax(keep_dims=False) | |||
| self.Act = nn.ReLU() | |||
| self.flatten = nn.Flatten() | |||
| self.dense = nn.Dense(2048, 50, activation='sigmoid') | |||
| self.sigmoid = nn.Sigmoid() | |||
| def construct(self, input_data): | |||
| """ | |||
| Music Tagger CNN | |||
| """ | |||
| x = self.conv1(input_data) | |||
| x = self.bn1(x) | |||
| x = self.Act(x) | |||
| x = self.pool1(x) | |||
| x = self.conv2(x) | |||
| x = self.bn2(x) | |||
| x = self.Act(x) | |||
| x = self.pool2(x) | |||
| x = self.conv3(x) | |||
| x = self.bn3(x) | |||
| x = self.Act(x) | |||
| x = self.pool3(x) | |||
| x = self.conv4(x) | |||
| x = self.bn4(x) | |||
| x = self.Act(x) | |||
| x = self.poolreduce(x, (2, 3)) | |||
| x = self.flatten(x) | |||
| x = self.dense(x) | |||
| return x | |||
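# With the configuration used by train.py and eval.py, an input batch of
# shape (batch, 1, 96, 1366) log-mel spectrograms yields a (batch, 50)
# tensor of per-tag probabilities (the final Dense layer applies a sigmoid).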
| @@ -0,0 +1,226 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
'''python pre_process_data.py'''
| import os | |||
| import argparse | |||
| import pandas as pd | |||
| import numpy as np | |||
| import librosa | |||
| from mindspore.mindrecord import FileWriter | |||
| from mindspore import context | |||
| from src.config import data_cfg as cfg | |||
| def compute_melgram(audio_path, save_path='', filename='', save_npy=True): | |||
| """ | |||
| extract melgram feature from the audio and save as numpy array | |||
| Args: | |||
| audio_path (str): path to the audio clip. | |||
| save_path (str): path to save the numpy array. | |||
| filename (str): filename of the audio clip. | |||
| Returns: | |||
| numpy array. | |||
| """ | |||
SR = 12000
N_FFT = 512
N_MELS = 96
HOP_LEN = 256
DURA = 29.12  # 12000 Hz * 29.12 s = 349440 samples -> 349440 // 256 + 1 = 1366 frames
| src, _ = librosa.load(audio_path, sr=SR) # whole signal | |||
| n_sample = src.shape[0] | |||
| n_sample_fit = int(DURA * SR) | |||
| if n_sample < n_sample_fit: # if too short | |||
| src = np.hstack((src, np.zeros((int(DURA * SR) - n_sample,)))) | |||
| elif n_sample > n_sample_fit: # if too long | |||
| src = src[(n_sample - n_sample_fit) // 2:(n_sample + n_sample_fit) // | |||
| 2] | |||
| logam = librosa.core.amplitude_to_db | |||
| melgram = librosa.feature.melspectrogram | |||
| ret = logam( | |||
| melgram(y=src, sr=SR, hop_length=HOP_LEN, n_fft=N_FFT, n_mels=N_MELS)) | |||
| ret = ret[np.newaxis, np.newaxis, :] | |||
| if save_npy: | |||
| save_path = save_path + filename[:-4] + '.npy' | |||
| np.save(save_path, ret) | |||
| return ret | |||
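# Example (hypothetical filenames): compute_melgram('./mp3/clip.mp3', './fea/',
# 'clip.mp3') returns an array of shape (1, 1, 96, 1366) and, since
# save_npy defaults to True, also writes it to './fea/clip.npy'.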
| def get_data(features_data, labels_data): | |||
| data_list = [] | |||
| for i, (label, feature) in enumerate(zip(labels_data, features_data)): | |||
| data_json = {"id": i, "feature": feature, "label": label} | |||
| data_list.append(data_json) | |||
| return data_list | |||
| def convert(s): | |||
| if s.isdigit(): | |||
| return int(s) | |||
| return s | |||
| def GetLabel(info_path, info_name): | |||
| """ | |||
| separate dataset into training set and validation set | |||
| Args: | |||
| info_path (str): path to the information file. | |||
| info_name (str): name of the information file. | |||
| """ | |||
| T = [] | |||
| with open(info_path + '/' + info_name, 'rb') as info: | |||
| data = info.readline() | |||
| while data: | |||
| T.append([ | |||
| convert(i[1:-1]) | |||
| for i in data.strip().decode('utf-8').split("\t") | |||
| ]) | |||
| data = info.readline() | |||
| annotation = pd.DataFrame(T[1:], columns=T[0]) | |||
| count = [] | |||
| for i in annotation.columns[1:-2]: | |||
| count.append([annotation[i].sum() / len(annotation), i]) | |||
| count = sorted(count) | |||
| full_label = [] | |||
| for i in count[-50:]: | |||
| full_label.append(i[1]) | |||
| out = [] | |||
| for i in T[1:]: | |||
| index = [k for k, x in enumerate(i) if x == 1] | |||
| label = [T[0][k] for k in index] | |||
| L = [str(0) for k in range(50)] | |||
| L.append(i[-1]) | |||
| for j in label: | |||
| if j in full_label: | |||
| ind = full_label.index(j) | |||
| L[ind] = '1' | |||
| out.append(L) | |||
| out = np.array(out) | |||
| Train = [] | |||
| Val = [] | |||
| for i in out: | |||
| if np.random.rand() > 0.2: | |||
| Train.append(i) | |||
| else: | |||
| Val.append(i) | |||
| np.savetxt("{}/music_tagging_train_tmp.csv".format(info_path), | |||
| np.array(Train), | |||
| fmt='%s', | |||
| delimiter=',') | |||
| np.savetxt("{}/music_tagging_val_tmp.csv".format(info_path), | |||
| np.array(Val), | |||
| fmt='%s', | |||
| delimiter=',') | |||
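# GetLabel writes two temporary label files next to the annotation file,
# splitting clips randomly (roughly 80/20) into music_tagging_train_tmp.csv
# and music_tagging_val_tmp.csv; each row holds 50 binary tag flags for the
# 50 most frequent tags, plus the clip's mp3 path.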
def generator_md(info_name, file_path, num_classes):
    """
    Generate numpy arrays holding the features and labels of all audio clips.

    Args:
        info_name (str): path to the information csv file.
        file_path (str): path to the npy files.
        num_classes (int): number of tag classes.

    Returns:
        2 numpy arrays (features, labels).
    """
    df = pd.read_csv(info_name, header=None)
    df.columns = [str(i) for i in range(num_classes)] + ["mp3_path"]
    data = []
    label = []
    for i in range(len(df)):
        try:
            data.append(
                np.load(file_path + df.mp3_path.values[i][:-4] +
                        '.npy').reshape(1, 96, 1366))
            label.append(np.array(df[df.columns[:-1]][i:i + 1])[0])
        except FileNotFoundError:
            print("Missing npy file for {}, skipping it.".format(
                df.mp3_path.values[i]))
    return np.array(data), np.array(label, dtype=np.int32)
| def convert_to_mindrecord(info_name, file_path, store_path, mr_name, | |||
| num_classes): | |||
| """ convert dataset to mindrecord """ | |||
| num_shard = 4 | |||
| data, label = generator_md(info_name, file_path, num_classes) | |||
| schema_json = { | |||
| "id": { | |||
| "type": "int32" | |||
| }, | |||
| "feature": { | |||
| "type": "float32", | |||
| "shape": [1, 96, 1366] | |||
| }, | |||
| "label": { | |||
| "type": "int32", | |||
| "shape": [num_classes] | |||
| } | |||
| } | |||
| writer = FileWriter( | |||
| os.path.join(store_path, '{}.mindrecord'.format(mr_name)), num_shard) | |||
| datax = get_data(data, label) | |||
| writer.add_schema(schema_json, "music_tagger_schema") | |||
| writer.add_index(["id"]) | |||
| writer.write_raw_data(datax) | |||
| writer.commit() | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description='get feature') | |||
| parser.add_argument('--device_id', | |||
| type=int, | |||
| help='device ID', | |||
| default=None) | |||
| args = parser.parse_args() | |||
| if cfg.get_npy: | |||
| GetLabel(cfg.info_path, cfg.info_name) | |||
| dirname = os.listdir(cfg.audio_path) | |||
| for d in dirname: | |||
| file_name = os.listdir("{}/{}".format(cfg.audio_path, d)) | |||
| if not os.path.isdir("{}/{}".format(cfg.npy_path, d)): | |||
| os.mkdir("{}/{}".format(cfg.npy_path, d)) | |||
| for f in file_name: | |||
| compute_melgram("{}/{}/{}".format(cfg.audio_path, d, f), | |||
| "{}/{}/".format(cfg.npy_path, d), f) | |||
| if cfg.get_mindrecord: | |||
| if args.device_id is not None: | |||
| context.set_context(device_target='Ascend', | |||
| mode=context.GRAPH_MODE, | |||
| device_id=args.device_id) | |||
| else: | |||
| context.set_context(device_target='Ascend', | |||
| mode=context.GRAPH_MODE, | |||
| device_id=cfg.device_id) | |||
for cmn in cfg.mr_name:
    if cmn in ['train', 'val']:
        convert_to_mindrecord(
            '{}/music_tagging_{}_tmp.csv'.format(cfg.info_path, cmn),
            cfg.npy_path, cfg.mr_path, cmn, cfg.num_classes)
| @@ -0,0 +1,50 @@ | |||
| choral | |||
| female voice | |||
| metal | |||
| country | |||
| weird | |||
| no voice | |||
| cello | |||
| harp | |||
| beats | |||
| female vocal | |||
| male voice | |||
| dance | |||
| new age | |||
| voice | |||
| choir | |||
| classic | |||
| man | |||
| solo | |||
| sitar | |||
| soft | |||
| no vocal | |||
| pop | |||
| male vocal | |||
| woman | |||
| flute | |||
| quiet | |||
| loud | |||
| harpsichord | |||
| no vocals | |||
| vocals | |||
| singing | |||
| male | |||
| opera | |||
| indian | |||
| female | |||
| synth | |||
| vocal | |||
| violin | |||
| beat | |||
| ambient | |||
| piano | |||
| fast | |||
| rock | |||
| electronic | |||
| drums | |||
| strings | |||
| techno | |||
| slow | |||
| classical | |||
| guitar | |||
| @@ -0,0 +1,109 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| ''' | |||
| ##############train models################# | |||
| python train.py | |||
| ''' | |||
| import argparse | |||
| from mindspore import context, nn | |||
| from mindspore.train import Model | |||
| from mindspore.common import set_seed | |||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||
| from src.dataset import create_dataset | |||
| from src.musictagger import MusicTaggerCNN | |||
| from src.loss import BCELoss | |||
| from src.config import music_cfg as cfg | |||
| def train(model, dataset_direct, filename, columns_list, num_consumer=4, | |||
| batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50, | |||
| prefix="model", directory='./'): | |||
| """ | |||
| train network | |||
| """ | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps, | |||
| keep_checkpoint_max=keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix=prefix, | |||
| directory=directory, | |||
| config=config_ck) | |||
| data_train = create_dataset(dataset_direct, filename, batch, columns_list, | |||
| num_consumer) | |||
| model.train(epoch, | |||
| data_train, | |||
| callbacks=[ | |||
| ckpoint_cb, | |||
| LossMonitor(per_print_times=181), | |||
| TimeMonitor() | |||
| ], | |||
| dataset_sink_mode=True) | |||
| if __name__ == "__main__": | |||
| set_seed(1) | |||
| parser = argparse.ArgumentParser(description='Train model') | |||
| parser.add_argument('--device_id', | |||
| type=int, | |||
| help='device ID', | |||
| default=None) | |||
| args = parser.parse_args() | |||
| if args.device_id is not None: | |||
| context.set_context(device_target='Ascend', | |||
| mode=context.GRAPH_MODE, | |||
| device_id=args.device_id) | |||
| else: | |||
| context.set_context(device_target='Ascend', | |||
| mode=context.GRAPH_MODE, | |||
| device_id=cfg.device_id) | |||
| context.set_context(enable_auto_mixed_precision=cfg.mixed_precision) | |||
| network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048], | |||
| kernel_size=[3, 3, 3, 3, 3], | |||
| padding=[0] * 5, | |||
| maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)], | |||
| has_bias=True) | |||
| if cfg.pre_trained: | |||
| param_dict = load_checkpoint(cfg.checkpoint_path + '/' + | |||
| cfg.model_name) | |||
| load_param_into_net(network, param_dict) | |||
| net_loss = BCELoss() | |||
| network.set_train(True) | |||
| net_opt = nn.Adam(params=network.trainable_params(), | |||
| learning_rate=cfg.lr, | |||
| loss_scale=cfg.loss_scale) | |||
| loss_scale_manager = FixedLossScaleManager(loss_scale=cfg.loss_scale, | |||
| drop_overflow_update=False) | |||
| net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager) | |||
| train(model=net_model, | |||
| dataset_direct=cfg.data_dir, | |||
| filename=cfg.train_filename, | |||
| columns_list=['feature', 'label'], | |||
| num_consumer=cfg.num_consumer, | |||
| batch=cfg.batch_size, | |||
| epoch=cfg.epoch_size, | |||
| save_checkpoint_steps=cfg.save_step, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max, | |||
| prefix=cfg.prefix, | |||
| directory=cfg.checkpoint_path + "_{}".format(cfg.device_id)) | |||
| print("train success") | |||