From f4ccfd4380b559adf2355fb74f6323fe5176d3bb Mon Sep 17 00:00:00 2001
From: jzg
Date: Wed, 23 Sep 2020 10:37:04 +0800
Subject: [PATCH] add music auto tagging network.

---
 .../audio/music_auto_tagging/README.md        | 203 ++++++++++++++++
 .../official/audio/music_auto_tagging/eval.py | 137 +++++++++++
 .../audio/music_auto_tagging/export.py        |  40 ++++
 .../music_auto_tagging/scripts/run_eval.sh    |  18 ++
 .../scripts/run_process_data.sh               |  18 ++
 .../music_auto_tagging/scripts/run_train.sh   |  18 ++
 .../audio/music_auto_tagging/src/__init__.py  |  23 ++
 .../audio/music_auto_tagging/src/config.py    |  53 ++++
 .../audio/music_auto_tagging/src/dataset.py   |  30 +++
 .../audio/music_auto_tagging/src/loss.py      |  41 ++++
 .../music_auto_tagging/src/musictagger.py     |  83 +++++++
 .../src/pre_process_data.py                   | 226 ++++++++++++++++++
 .../audio/music_auto_tagging/src/tag.txt      |  50 ++++
 .../audio/music_auto_tagging/train.py         | 109 +++++++++
 14 files changed, 1049 insertions(+)
 create mode 100644 model_zoo/official/audio/music_auto_tagging/README.md
 create mode 100644 model_zoo/official/audio/music_auto_tagging/eval.py
 create mode 100644 model_zoo/official/audio/music_auto_tagging/export.py
 create mode 100644 model_zoo/official/audio/music_auto_tagging/scripts/run_eval.sh
 create mode 100644 model_zoo/official/audio/music_auto_tagging/scripts/run_process_data.sh
 create mode 100644 model_zoo/official/audio/music_auto_tagging/scripts/run_train.sh
 create mode 100644 model_zoo/official/audio/music_auto_tagging/src/__init__.py
 create mode 100644 model_zoo/official/audio/music_auto_tagging/src/config.py
 create mode 100644 model_zoo/official/audio/music_auto_tagging/src/dataset.py
 create mode 100644 model_zoo/official/audio/music_auto_tagging/src/loss.py
 create mode 100644 model_zoo/official/audio/music_auto_tagging/src/musictagger.py
 create mode 100644 model_zoo/official/audio/music_auto_tagging/src/pre_process_data.py
 create mode 100644 model_zoo/official/audio/music_auto_tagging/src/tag.txt
 create mode 100644 model_zoo/official/audio/music_auto_tagging/train.py

diff --git a/model_zoo/official/audio/music_auto_tagging/README.md b/model_zoo/official/audio/music_auto_tagging/README.md
new file mode 100644
index 0000000000..b175b768ca
--- /dev/null
+++ b/model_zoo/official/audio/music_auto_tagging/README.md
@@ -0,0 +1,203 @@
# Contents

- [Music Auto Tagging Description](#music-auto-tagging-description)
- [Model Architecture](#model-architecture)
- [Features](#features)
  - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
  - [Script and Sample Code](#script-and-sample-code)
  - [Script Parameters](#script-parameters)
  - [Training Process](#training-process)
    - [Training](#training)
  - [Evaluation Process](#evaluation-process)
    - [Evaluation](#evaluation)
- [Model Description](#model-description)
  - [Performance](#performance)
    - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)


# [Music Auto Tagging Description](#contents)

This repository provides a script and recipe to train the Music Auto Tagging (FCN-4) model to achieve state-of-the-art accuracy.

[Paper](https://arxiv.org/abs/1606.00298): Keunwoo Choi, George Fazekas, and Mark Sandler, "Automatic tagging using deep convolutional neural networks," in International Society of Music Information Retrieval Conference. ISMIR, 2016.

# [Model Architecture](#contents)

Music Auto Tagging (FCN-4) is a convolutional neural network architecture; the "4" in its name comes from its four convolutional blocks. The network consists of convolutional layers, max-pooling layers, activation layers, and a fully connected layer.

# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, while maintaining the network accuracy achieved with single-precision training. Mixed precision training accelerates computation, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the MindSpore backend automatically handles it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and searching for 'reduce precision'. A minimal sketch of how this repository switches mixed precision on appears at the end of the Train step below.

# [Environment Requirements](#contents)

- Hardware (Ascend)
  - If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
  - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
  - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
  - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)


# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

### 1. Download and preprocess the dataset

1. Download a classification dataset (for instance, the MagnaTagATune dataset or the Million Song Dataset).
2. Extract the dataset.
3. The information file of each clip should contain the label and the path. Please refer to annotations_final.csv in the MagnaTagATune dataset.
4. The provided pre-processing script uses the MagnaTagATune dataset as an example. Please modify the code according to your own needs.

### 2. Set up parameters (src/config.py)

### 3. Train

After preparing your dataset, first convert the audio clips into a MindRecord dataset with the following command:

```shell
python pre_process_data.py --device_id 0
```

Then you can start training the model:

```shell
SLOG_PRINT_TO_STDOUT=1 python train.py --device_id 0
```
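
Mixed precision (see [Mixed Precision](#mixed-precision)) is controlled by the `mixed_precision` flag in `src/config.py`; `train.py` simply forwards that flag to the MindSpore context. A minimal sketch of the switch, using only names defined in this repository:

```python
from mindspore import context
from src.config import music_cfg as cfg

cfg.mixed_precision = True  # opt in to Ascend auto mixed precision
# train.py passes the flag through unchanged:
context.set_context(enable_auto_mixed_precision=cfg.mixed_precision)
```
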
### 4. Test

Then you can test your model:

```shell
SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

```
├── model_zoo
    ├── README.md // descriptions about all the models
    ├── music_auto_tagging
        ├── README.md // descriptions about the music auto tagging model
        ├── scripts
        │   ├──run_train.sh // shell script for training on Ascend
        │   ├──run_eval.sh // shell script for evaluation on Ascend
        │   ├──run_process_data.sh // shell script for converting audio clips to mindrecord
        ├── src
        │   ├──dataset.py // creating dataset
        │   ├──pre_process_data.py // pre-processing dataset
        │   ├──musictagger.py // music auto tagging (FCN-4) architecture
        │   ├──config.py // parameter configuration
        │   ├──loss.py // loss function
        │   ├──tag.txt // tag name for each label index
        ├── train.py // training script
        ├── eval.py // evaluation script
        ├── export.py // export model in AIR format
```

## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- config for Music Auto Tagging

  ```python
  # data_cfg: pre-processing parameters
  'num_classes': 50 # number of tagging classes
  'num_consumer': 4 # file number for mindrecord
  'get_npy': 1 # mode for converting audio to npy, default 1 in this case
  'get_mindrecord': 1 # mode for converting npy files into mindrecord files, default 1 in this case
  'audio_path': "/dev/data/Music_Tagger_Data/fea/" # path to audio clips
  'npy_path': "/dev/data/Music_Tagger_Data/fea/" # path to numpy files
  'info_path': "/dev/data/Music_Tagger_Data/fea/" # path to info_name, which provides the label of each audio clip
  'info_name': 'annotations_final.csv' # info_name
  'device_target': 'Ascend' # device running the program
  'device_id': 0 # device ID used to train or evaluate the dataset. Ignore it when you use run_train.sh for distributed training
  'mr_path': '/dev/data/Music_Tagger_Data/fea/' # path to mindrecord
  'mr_name': ['train', 'val'] # mindrecord names

  # music_cfg: training and evaluation parameters
  'pre_trained': False # whether to train based on a pre-trained model
  'lr': 0.0005 # learning rate
  'batch_size': 32 # training batch size
  'epoch_size': 10 # total training epochs
  'loss_scale': 1024.0 # loss scale
  'num_consumer': 4 # file number for mindrecord
  'mixed_precision': False # whether to use mixed precision calculation
  'train_filename': 'train.mindrecord0' # file name of the train mindrecord data
  'val_filename': 'val.mindrecord0' # file name of the evaluation mindrecord data
  'data_dir': '/dev/data/Music_Tagger_Data/fea/' # directory of mindrecord data
  'device_target': 'Ascend' # device running the program
  'device_id': 0 # device ID used to train or evaluate the dataset. Ignore it when you use run_train.sh for distributed training
  'keep_checkpoint_max': 10 # only keep the last keep_checkpoint_max checkpoints
  'save_step': 2000 # steps between checkpoint saves
  'checkpoint_path': '/dev/data/Music_Tagger_Data/model/' # the absolute path to save the checkpoint file
  'prefix': 'MusicTagger' # prefix of checkpoint files
  'model_name': 'MusicTagger_3-50_543.ckpt' # checkpoint name
  ```

## [Training Process](#contents)

### Training

- running on Ascend

  ```
  python train.py > train.log 2>&1 &
  ```

  The python command above will run in the background; you can view the results through the file `train.log`.

  After training, you will get some checkpoint files under the script folder by default. The loss values will be printed as follows:

  ```
  # grep "loss is " train.log
  epoch: 1 step: 100, loss is 0.23264095
  epoch: 1 step: 200, loss is 0.2013525
  ...
  ```

  The model checkpoint will be saved in the configured checkpoint directory (`checkpoint_path` in config.py).

## [Evaluation Process](#contents)

### Evaluation
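
- running on Ascend

  Evaluation mirrors the Quick Start test step; the checkpoint directory and file name are read from `src/config.py` (`checkpoint_path`, `model_name`), so set them before running:

  ```
  SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
  ```

  Alternatively, `sh scripts/run_eval.sh` wraps the same command. The script prints the validation AUC at the end (see `eval.py`).
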
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                 | Ascend                                                      |
| -------------------------- | ----------------------------------------------------------- |
| Model Version              | FCN-4                                                       |
| Resource                   | Ascend 910; CPU 2.60 GHz, 56 cores; memory 314 GB           |
| Uploaded Date              | 09/11/2020 (month/day/year)                                 |
| MindSpore Version          | r0.7.0                                                      |
| Training Parameters        | epoch=10, steps=534, batch_size=32, lr=0.005                |
| Optimizer                  | Adam                                                        |
| Loss Function              | Binary cross entropy                                        |
| Outputs                    | probability                                                 |
| Accuracy                   | AUC 0.909                                                   |
| Speed                      | 1 pc: 160 samples/sec                                       |
| Total Time                 | 1 pc: 20 mins                                               |
| Checkpoint for Fine Tuning | 198.73 MB (.ckpt file)                                      |
| Scripts                    | [music_auto_tagging script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/audio/music_auto_tagging) |

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
diff --git a/model_zoo/official/audio/music_auto_tagging/eval.py b/model_zoo/official/audio/music_auto_tagging/eval.py
new file mode 100644
index 0000000000..7f8b21fd14
--- /dev/null
+++ b/model_zoo/official/audio/music_auto_tagging/eval.py
@@ -0,0 +1,137 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# ============================================================================ +''' +##############evaluate trained models################# +python eval.py +''' + +import argparse +import numpy as np +import mindspore.common.dtype as mstype +from mindspore import context +from mindspore import Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.musictagger import MusicTaggerCNN +from src.config import music_cfg as cfg +from src.dataset import create_dataset + + +def calculate_auc(labels_list, preds_list): + """ + The AUC calculation function + Input: + labels_list: list of true label + preds_list: list of predicted label + Outputs + Float, means of AUC + """ + auc = [] + n_bins = labels_list.shape[0] // 2 + if labels_list.ndim == 1: + labels_list = labels_list.reshape(-1, 1) + preds_list = preds_list.reshape(-1, 1) + for i in range(labels_list.shape[1]): + labels = labels_list[:, i] + preds = preds_list[:, i] + postive_len = labels.sum() + negative_len = labels.shape[0] - postive_len + total_case = postive_len * negative_len + positive_histogram = np.zeros((n_bins)) + negative_histogram = np.zeros((n_bins)) + bin_width = 1.0 / n_bins + + for j, _ in enumerate(labels): + nth_bin = int(preds[j] // bin_width) + if labels[j]: + positive_histogram[nth_bin] = positive_histogram[nth_bin] + 1 + else: + negative_histogram[nth_bin] = negative_histogram[nth_bin] + 1 + + accumulated_negative = 0 + satisfied_pair = 0 + for k in range(n_bins): + satisfied_pair += ( + positive_histogram[k] * accumulated_negative + + positive_histogram[k] * negative_histogram[k] * 0.5) + accumulated_negative += negative_histogram[k] + auc.append(satisfied_pair / total_case) + + return np.mean(auc) + + +def val(net, data_dir, filename, num_consumer=4, batch=32): + """ + Validation function, estimate the performance of trained model + + Input: + net: the trained neural network + data_dir: path to the validation dataset + filename: name of the validation dataset + num_consumer: split number of validation dataset + batch: validation batch size + Outputs + Float, AUC + """ + data_train = create_dataset(data_dir, filename, 32, ['feature', 'label'], + num_consumer) + data_train = data_train.create_tuple_iterator() + res_pred = [] + res_true = [] + for data, label in data_train: + x = net(Tensor(data, dtype=mstype.float32)) + res_pred.append(x.asnumpy()) + res_true.append(label.asnumpy()) + res_pred = np.concatenate(res_pred, axis=0) + res_true = np.concatenate(res_true, axis=0) + auc = calculate_auc(res_true, res_pred) + return auc + + +def validation(net, model_path, data_dir, filename, num_consumer, batch): + param_dict = load_checkpoint(model_path) + load_param_into_net(net, param_dict) + + auc = val(net, data_dir, filename, num_consumer, batch) + return auc + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Evaluate model') + parser.add_argument('--device_id', + type=int, + help='device ID', + default=None) + args = parser.parse_args() + + if args.device_id is not None: + context.set_context(device_target=cfg.device_target, + mode=context.GRAPH_MODE, + device_id=args.device_id) + else: + context.set_context(device_target=cfg.device_target, + mode=context.GRAPH_MODE, + device_id=cfg.device_id) + + network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048], + kernel_size=[3, 3, 3, 3, 3], + padding=[0] * 5, + maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)], + has_bias=True) + network.set_train(False) + auc_val = validation(network, cfg.checkpoint_path + "/" + 


def val(net, data_dir, filename, num_consumer=4, batch=32):
    """
    Validation function, estimates the performance of the trained model.

    Args:
        net: the trained neural network
        data_dir: path to the validation dataset
        filename: name of the validation dataset
        num_consumer: split number of the validation dataset
        batch: validation batch size

    Returns:
        Float, AUC
    """
    data_train = create_dataset(data_dir, filename, batch,
                                ['feature', 'label'], num_consumer)
    data_train = data_train.create_tuple_iterator()
    res_pred = []
    res_true = []
    for data, label in data_train:
        x = net(Tensor(data, dtype=mstype.float32))
        res_pred.append(x.asnumpy())
        res_true.append(label.asnumpy())
    res_pred = np.concatenate(res_pred, axis=0)
    res_true = np.concatenate(res_true, axis=0)
    auc = calculate_auc(res_true, res_pred)
    return auc


def validation(net, model_path, data_dir, filename, num_consumer, batch):
    """Load a checkpoint into the network and return its validation AUC."""
    param_dict = load_checkpoint(model_path)
    load_param_into_net(net, param_dict)

    auc = val(net, data_dir, filename, num_consumer, batch)
    return auc


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate model')
    parser.add_argument('--device_id',
                        type=int,
                        help='device ID',
                        default=None)
    args = parser.parse_args()

    if args.device_id is not None:
        context.set_context(device_target=cfg.device_target,
                            mode=context.GRAPH_MODE,
                            device_id=args.device_id)
    else:
        context.set_context(device_target=cfg.device_target,
                            mode=context.GRAPH_MODE,
                            device_id=cfg.device_id)

    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    network.set_train(False)
    auc_val = validation(network, cfg.checkpoint_path + "/" + cfg.model_name,
                         cfg.data_dir, cfg.val_filename, cfg.num_consumer,
                         cfg.batch_size)

    print("=" * 10 + "Validation Performance" + "=" * 10)
    print("AUC: {:.5f}".format(auc_val))
diff --git a/model_zoo/official/audio/music_auto_tagging/export.py b/model_zoo/official/audio/music_auto_tagging/export.py
new file mode 100644
index 0000000000..3e1b2da057
--- /dev/null
+++ b/model_zoo/official/audio/music_auto_tagging/export.py
@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''
##############export trained model#################
python export.py
'''

import numpy as np
from mindspore.train.serialization import export
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.musictagger import MusicTaggerCNN
from src.config import music_cfg as cfg

if __name__ == "__main__":
    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    param_dict = load_checkpoint(cfg.checkpoint_path + "/" + cfg.model_name)
    load_param_into_net(network, param_dict)
    input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32)
    export(network,
           Tensor(input_data),
           filename="{}/{}.air".format(cfg.checkpoint_path,
                                       cfg.model_name[:-5]),
           file_format="AIR")
diff --git a/model_zoo/official/audio/music_auto_tagging/scripts/run_eval.sh b/model_zoo/official/audio/music_auto_tagging/scripts/run_eval.sh
new file mode 100644
index 0000000000..d402e7dd19
--- /dev/null
+++ b/model_zoo/official/audio/music_auto_tagging/scripts/run_eval.sh
@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# ============================================================================ + +export SLOG_PRINT_TO_STDOUT=1 +python ../eval.py --device_id 0 \ No newline at end of file diff --git a/model_zoo/official/audio/music_auto_tagging/scripts/run_process_data.sh b/model_zoo/official/audio/music_auto_tagging/scripts/run_process_data.sh new file mode 100644 index 0000000000..57a95ec28e --- /dev/null +++ b/model_zoo/official/audio/music_auto_tagging/scripts/run_process_data.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export SLOG_PRINT_TO_STDOUT=1 +python ../src/pre_process_data.py --device_id 0 diff --git a/model_zoo/official/audio/music_auto_tagging/scripts/run_train.sh b/model_zoo/official/audio/music_auto_tagging/scripts/run_train.sh new file mode 100644 index 0000000000..cde7a574aa --- /dev/null +++ b/model_zoo/official/audio/music_auto_tagging/scripts/run_train.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export SLOG_PRINT_TO_STDOUT=1 +python ../train.py --device_id 0 \ No newline at end of file diff --git a/model_zoo/official/audio/music_auto_tagging/src/__init__.py b/model_zoo/official/audio/music_auto_tagging/src/__init__.py new file mode 100644 index 0000000000..62d5dab6cc --- /dev/null +++ b/model_zoo/official/audio/music_auto_tagging/src/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +__init__.py +""" + +from . import musictagger +from . import loss +from . import dataset +from . import config +from . 
diff --git a/model_zoo/official/audio/music_auto_tagging/src/config.py b/model_zoo/official/audio/music_auto_tagging/src/config.py
new file mode 100644
index 0000000000..ed47598d60
--- /dev/null
+++ b/model_zoo/official/audio/music_auto_tagging/src/config.py
@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config settings, used in train.py, eval.py and src/pre_process_data.py
"""
from easydict import EasyDict as edict

data_cfg = edict({
    'num_classes': 50,
    'num_consumer': 4,
    'get_npy': 1,
    'get_mindrecord': 1,
    'audio_path': "/dev/data/Music_Tagger_Data/fea/",
    'npy_path': "/dev/data/Music_Tagger_Data/fea/",
    'info_path': "/dev/data/Music_Tagger_Data/fea/",
    'info_name': 'annotations_final.csv',
    'device_target': 'Ascend',
    'device_id': 0,
    'mr_path': '/dev/data/Music_Tagger_Data/fea/',
    'mr_name': ['train', 'val'],
})

music_cfg = edict({
    'pre_trained': False,
    'lr': 0.0005,
    'batch_size': 32,
    'epoch_size': 10,
    'loss_scale': 1024.0,
    'num_consumer': 4,
    'mixed_precision': False,
    'train_filename': 'train.mindrecord0',
    'val_filename': 'val.mindrecord0',
    'data_dir': '/dev/data/Music_Tagger_Data/fea/',
    'device_target': 'Ascend',
    'device_id': 0,
    'keep_checkpoint_max': 10,
    'save_step': 2000,
    'checkpoint_path': '/dev/data/Music_Tagger_Data/model',
    'prefix': 'MusicTagger',
    'model_name': 'MusicTagger_3-50_543.ckpt',
})
diff --git a/model_zoo/official/audio/music_auto_tagging/src/dataset.py b/model_zoo/official/audio/music_auto_tagging/src/dataset.py
new file mode 100644
index 0000000000..14dfb7d033
--- /dev/null
+++ b/model_zoo/official/audio/music_auto_tagging/src/dataset.py
@@ -0,0 +1,30 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# ============================================================================ +'''python dataset.py''' + +import os +import mindspore.dataset as ds + + +def create_dataset(base_path, filename, batch_size, columns_list, + num_consumer): + """Create dataset""" + + path = os.path.join(base_path, filename) + dtrain = ds.MindDataset(path, columns_list, num_consumer) + dtrain = dtrain.shuffle(buffer_size=dtrain.get_dataset_size()) + dtrain = dtrain.batch(batch_size, drop_remainder=True) + + return dtrain diff --git a/model_zoo/official/audio/music_auto_tagging/src/loss.py b/model_zoo/official/audio/music_auto_tagging/src/loss.py new file mode 100644 index 0000000000..2463863e40 --- /dev/null +++ b/model_zoo/official/audio/music_auto_tagging/src/loss.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +define loss +""" +from mindspore import nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P + + + +class BCELoss(nn.Cell): + """ + BCELoss + """ + def __init__(self, record=None): + super(BCELoss, self).__init__(record) + self.sm_scalar = P.ScalarSummary() + self.cast = P.Cast() + self.record = record + self.weight = None + self.bce = P.BinaryCrossEntropy() + + def construct(self, input_data, target): + target = self.cast(target, mstype.float32) + loss = self.bce(input_data, target, self.weight) + if self.record: + self.sm_scalar("loss", loss) + return loss diff --git a/model_zoo/official/audio/music_auto_tagging/src/musictagger.py b/model_zoo/official/audio/music_auto_tagging/src/musictagger.py new file mode 100644 index 0000000000..4f02a36459 --- /dev/null +++ b/model_zoo/official/audio/music_auto_tagging/src/musictagger.py @@ -0,0 +1,83 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +'''model''' + +from mindspore import nn +from mindspore.ops import operations as P + + +class MusicTaggerCNN(nn.Cell): + """ + Music Tagger CNN + """ + def __init__(self, in_classes, kernel_size, padding, maxpool, has_bias): + super(MusicTaggerCNN, self).__init__() + self.in_classes = in_classes + self.kernel_size = kernel_size + self.maxpool = maxpool + self.padding = padding + self.has_bias = has_bias + # build model + self.conv1 = nn.Conv2d(self.in_classes[0], self.in_classes[1], + self.kernel_size[0]) + self.conv2 = nn.Conv2d(self.in_classes[1], self.in_classes[2], + self.kernel_size[1]) + self.conv3 = nn.Conv2d(self.in_classes[2], self.in_classes[3], + self.kernel_size[2]) + self.conv4 = nn.Conv2d(self.in_classes[3], self.in_classes[4], + self.kernel_size[3]) + + self.bn1 = nn.BatchNorm2d(self.in_classes[1]) + self.bn2 = nn.BatchNorm2d(self.in_classes[2]) + self.bn3 = nn.BatchNorm2d(self.in_classes[3]) + self.bn4 = nn.BatchNorm2d(self.in_classes[4]) + + self.pool1 = nn.MaxPool2d(maxpool[0], maxpool[0]) + self.pool2 = nn.MaxPool2d(maxpool[1], maxpool[1]) + self.pool3 = nn.MaxPool2d(maxpool[2], maxpool[2]) + self.pool4 = nn.MaxPool2d(maxpool[3], maxpool[3]) + self.poolreduce = P.ReduceMax(keep_dims=False) + self.Act = nn.ReLU() + self.flatten = nn.Flatten() + self.dense = nn.Dense(2048, 50, activation='sigmoid') + self.sigmoid = nn.Sigmoid() + + def construct(self, input_data): + """ + Music Tagger CNN + """ + x = self.conv1(input_data) + x = self.bn1(x) + x = self.Act(x) + x = self.pool1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.Act(x) + x = self.pool2(x) + + x = self.conv3(x) + x = self.bn3(x) + x = self.Act(x) + x = self.pool3(x) + + x = self.conv4(x) + x = self.bn4(x) + x = self.Act(x) + x = self.poolreduce(x, (2, 3)) + x = self.flatten(x) + x = self.dense(x) + + return x diff --git a/model_zoo/official/audio/music_auto_tagging/src/pre_process_data.py b/model_zoo/official/audio/music_auto_tagging/src/pre_process_data.py new file mode 100644 index 0000000000..39f59dd01c --- /dev/null +++ b/model_zoo/official/audio/music_auto_tagging/src/pre_process_data.py @@ -0,0 +1,226 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +'''python dataset.py''' + +import os +import argparse +import pandas as pd +import numpy as np +import librosa +from mindspore.mindrecord import FileWriter +from mindspore import context +from src.config import data_cfg as cfg + + +def compute_melgram(audio_path, save_path='', filename='', save_npy=True): + """ + extract melgram feature from the audio and save as numpy array + + Args: + audio_path (str): path to the audio clip. + save_path (str): path to save the numpy array. + filename (str): filename of the audio clip. + + Returns: + numpy array. + + """ + SR = 12000 + N_FFT = 512 + N_MELS = 96 + HOP_LEN = 256 + DURA = 29.12 # to make it 1366 frame.. 
    src, _ = librosa.load(audio_path, sr=SR)  # whole signal
    n_sample = src.shape[0]
    n_sample_fit = int(DURA * SR)

    if n_sample < n_sample_fit:  # if too short
        src = np.hstack((src, np.zeros((int(DURA * SR) - n_sample,))))
    elif n_sample > n_sample_fit:  # if too long
        src = src[(n_sample - n_sample_fit) // 2:(n_sample + n_sample_fit) //
                  2]
    logam = librosa.core.amplitude_to_db
    melgram = librosa.feature.melspectrogram
    ret = logam(
        melgram(y=src, sr=SR, hop_length=HOP_LEN, n_fft=N_FFT, n_mels=N_MELS))
    ret = ret[np.newaxis, np.newaxis, :]
    if save_npy:
        save_path = save_path + filename[:-4] + '.npy'
        np.save(save_path, ret)
    return ret


def get_data(features_data, labels_data):
    """Pack features and labels into the list-of-dict layout expected by FileWriter."""
    data_list = []
    for i, (label, feature) in enumerate(zip(labels_data, features_data)):
        data_json = {"id": i, "feature": feature, "label": label}
        data_list.append(data_json)
    return data_list


def convert(s):
    if s.isdigit():
        return int(s)
    return s


def GetLabel(info_path, info_name):
    """
    Keep the 50 most frequent tags and separate the dataset into a training
    set and a validation set (roughly 80/20).

    Args:
        info_path (str): path to the information file.
        info_name (str): name of the information file.
    """
    T = []
    with open(info_path + '/' + info_name, 'rb') as info:
        data = info.readline()
        while data:
            T.append([
                convert(i[1:-1])
                for i in data.strip().decode('utf-8').split("\t")
            ])
            data = info.readline()

    annotation = pd.DataFrame(T[1:], columns=T[0])
    # rank tags by how often they occur and keep the 50 most frequent ones
    count = []
    for i in annotation.columns[1:-2]:
        count.append([annotation[i].sum() / len(annotation), i])
    count = sorted(count)
    full_label = []
    for i in count[-50:]:
        full_label.append(i[1])
    # re-encode every clip as a 50-dim binary label vector plus its mp3 path
    out = []
    for i in T[1:]:
        index = [k for k, x in enumerate(i) if x == 1]
        label = [T[0][k] for k in index]
        L = [str(0) for k in range(50)]
        L.append(i[-1])
        for j in label:
            if j in full_label:
                ind = full_label.index(j)
                L[ind] = '1'
        out.append(L)
    out = np.array(out)

    Train = []
    Val = []

    for i in out:
        if np.random.rand() > 0.2:
            Train.append(i)
        else:
            Val.append(i)
    np.savetxt("{}/music_tagging_train_tmp.csv".format(info_path),
               np.array(Train),
               fmt='%s',
               delimiter=',')
    np.savetxt("{}/music_tagging_val_tmp.csv".format(info_path),
               np.array(Val),
               fmt='%s',
               delimiter=',')


def generator_md(info_name, file_path, num_classes):
    """
    Load the saved .npy features of all audio clips listed in a csv file.

    Args:
        info_name (str): path to the csv file produced by GetLabel.
        file_path (str): path to the npy files.
        num_classes (int): number of label columns in the csv file.

    Returns:
        Two numpy arrays: features and labels.
    """
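    # The tmp csv written by GetLabel has num_classes binary label columns
    # followed by the clip's mp3 path in the last column.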
+ + """ + df = pd.read_csv(info_name, header=None) + df.columns = [str(i) for i in range(num_classes)] + ["mp3_path"] + data = [] + label = [] + for i in range(len(df)): + try: + data.append( + np.load(file_path + df.mp3_path.values[i][:-4] + + '.npy').reshape(1, 96, 1366)) + label.append(np.array(df[df.columns[:-1]][i:i + 1])[0]) + except FileNotFoundError: + print("Exception occurred in generator_md.") + return np.array(data), np.array(label, dtype=np.int32) + + +def convert_to_mindrecord(info_name, file_path, store_path, mr_name, + num_classes): + """ convert dataset to mindrecord """ + num_shard = 4 + data, label = generator_md(info_name, file_path, num_classes) + schema_json = { + "id": { + "type": "int32" + }, + "feature": { + "type": "float32", + "shape": [1, 96, 1366] + }, + "label": { + "type": "int32", + "shape": [num_classes] + } + } + + writer = FileWriter( + os.path.join(store_path, '{}.mindrecord'.format(mr_name)), num_shard) + datax = get_data(data, label) + writer.add_schema(schema_json, "music_tagger_schema") + writer.add_index(["id"]) + writer.write_raw_data(datax) + writer.commit() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='get feature') + parser.add_argument('--device_id', + type=int, + help='device ID', + default=None) + args = parser.parse_args() + + if cfg.get_npy: + GetLabel(cfg.info_path, cfg.info_name) + dirname = os.listdir(cfg.audio_path) + for d in dirname: + file_name = os.listdir("{}/{}".format(cfg.audio_path, d)) + if not os.path.isdir("{}/{}".format(cfg.npy_path, d)): + os.mkdir("{}/{}".format(cfg.npy_path, d)) + for f in file_name: + compute_melgram("{}/{}/{}".format(cfg.audio_path, d, f), + "{}/{}/".format(cfg.npy_path, d), f) + + if cfg.get_mindrecord: + if args.device_id is not None: + context.set_context(device_target='Ascend', + mode=context.GRAPH_MODE, + device_id=args.device_id) + else: + context.set_context(device_target='Ascend', + mode=context.GRAPH_MODE, + device_id=cfg.device_id) + for cmn in cfg.mr_nam: + if cmn in ['train', 'val']: + convert_to_mindrecord('music_tagging_{}_tmp.csv'.format(cmn), + cfg.npy_path, cfg.mr_path, cmn, + cfg.num_classes) diff --git a/model_zoo/official/audio/music_auto_tagging/src/tag.txt b/model_zoo/official/audio/music_auto_tagging/src/tag.txt new file mode 100644 index 0000000000..2926c1a64d --- /dev/null +++ b/model_zoo/official/audio/music_auto_tagging/src/tag.txt @@ -0,0 +1,50 @@ +choral +female voice +metal +country +weird +no voice +cello +harp +beats +female vocal +male voice +dance +new age +voice +choir +classic +man +solo +sitar +soft +no vocal +pop +male vocal +woman +flute +quiet +loud +harpsichord +no vocals +vocals +singing +male +opera +indian +female +synth +vocal +violin +beat +ambient +piano +fast +rock +electronic +drums +strings +techno +slow +classical +guitar diff --git a/model_zoo/official/audio/music_auto_tagging/train.py b/model_zoo/official/audio/music_auto_tagging/train.py new file mode 100644 index 0000000000..63ef965c25 --- /dev/null +++ b/model_zoo/official/audio/music_auto_tagging/train.py @@ -0,0 +1,109 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' +##############train models################# +python train.py +''' +import argparse +from mindspore import context, nn +from mindspore.train import Model +from mindspore.common import set_seed +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from src.dataset import create_dataset +from src.musictagger import MusicTaggerCNN +from src.loss import BCELoss +from src.config import music_cfg as cfg + +def train(model, dataset_direct, filename, columns_list, num_consumer=4, + batch=16, epoch=50, save_checkpoint_steps=2172, keep_checkpoint_max=50, + prefix="model", directory='./'): + """ + train network + """ + config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps, + keep_checkpoint_max=keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix=prefix, + directory=directory, + config=config_ck) + data_train = create_dataset(dataset_direct, filename, batch, columns_list, + num_consumer) + + + model.train(epoch, + data_train, + callbacks=[ + ckpoint_cb, + LossMonitor(per_print_times=181), + TimeMonitor() + ], + dataset_sink_mode=True) + + +if __name__ == "__main__": + set_seed(1) + parser = argparse.ArgumentParser(description='Train model') + parser.add_argument('--device_id', + type=int, + help='device ID', + default=None) + + args = parser.parse_args() + + if args.device_id is not None: + context.set_context(device_target='Ascend', + mode=context.GRAPH_MODE, + device_id=args.device_id) + else: + context.set_context(device_target='Ascend', + mode=context.GRAPH_MODE, + device_id=cfg.device_id) + + context.set_context(enable_auto_mixed_precision=cfg.mixed_precision) + network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048], + kernel_size=[3, 3, 3, 3, 3], + padding=[0] * 5, + maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)], + has_bias=True) + + if cfg.pre_trained: + param_dict = load_checkpoint(cfg.checkpoint_path + '/' + + cfg.model_name) + load_param_into_net(network, param_dict) + + net_loss = BCELoss() + + network.set_train(True) + net_opt = nn.Adam(params=network.trainable_params(), + learning_rate=cfg.lr, + loss_scale=cfg.loss_scale) + + loss_scale_manager = FixedLossScaleManager(loss_scale=cfg.loss_scale, + drop_overflow_update=False) + net_model = Model(network, net_loss, net_opt, loss_scale_manager=loss_scale_manager) + + train(model=net_model, + dataset_direct=cfg.data_dir, + filename=cfg.train_filename, + columns_list=['feature', 'label'], + num_consumer=cfg.num_consumer, + batch=cfg.batch_size, + epoch=cfg.epoch_size, + save_checkpoint_steps=cfg.save_step, + keep_checkpoint_max=cfg.keep_checkpoint_max, + prefix=cfg.prefix, + directory=cfg.checkpoint_path + "_{}".format(cfg.device_id)) + print("train success")