|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- '''
- ##############evaluate trained models#################
- python eval.py
- '''
-
- import argparse
- import numpy as np
- import mindspore.common.dtype as mstype
- from mindspore import context
- from mindspore import Tensor
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from src.musictagger import MusicTaggerCNN
- from src.config import music_cfg as cfg
- from src.dataset import create_dataset
-
-
def calculate_auc(labels_list, preds_list):
    """
    Compute the mean histogram-approximated AUC over all label columns.

    Predictions are bucketed into ``n_bins`` equal-width bins over [0, 1],
    where ``n_bins`` is half the number of samples (at least 1). For each
    column the AUC is the fraction of (positive, negative) pairs in which the
    positive sample falls into a higher bin than the negative one; pairs that
    share a bin count as half.

    Input:
        labels_list: array-like of true labels, shape (samples,) or
            (samples, classes); non-zero entries are treated as positive
        preds_list: array-like of predicted scores with the same shape,
            expected to lie in [0, 1]
    Outputs
        Float, mean of the per-column AUC values. A column containing only
        one class has no (positive, negative) pairs and contributes NaN, so
        the returned mean is NaN in that case.
    """
    labels_list = np.asarray(labels_list)
    preds_list = np.asarray(preds_list)

    # Use at least one bin: with fewer than two samples the integer division
    # yields 0 and the bin width below would divide by zero.
    n_bins = max(labels_list.shape[0] // 2, 1)
    if labels_list.ndim == 1:
        labels_list = labels_list.reshape(-1, 1)
        preds_list = preds_list.reshape(-1, 1)
    bin_width = 1.0 / n_bins

    auc = []
    for i in range(labels_list.shape[1]):
        is_positive = labels_list[:, i].astype(bool)
        # Bin in float64 so boundaries do not depend on the score dtype.
        preds = preds_list[:, i].astype(np.float64)

        # A score of exactly 1.0 maps to index n_bins (one past the last bin)
        # and a negative score maps to a negative index; clamp both into the
        # valid range instead of failing or wrapping around.
        bin_index = np.clip((preds // bin_width).astype(np.int64),
                            0, n_bins - 1)
        positive_histogram = np.bincount(
            bin_index[is_positive], minlength=n_bins).astype(np.float64)
        negative_histogram = np.bincount(
            bin_index[~is_positive], minlength=n_bins).astype(np.float64)

        positive_len = positive_histogram.sum()
        negative_len = negative_histogram.sum()
        total_case = positive_len * negative_len
        if total_case == 0:
            # Only one class present: the AUC is undefined for this column.
            auc.append(float("nan"))
            continue

        # Negatives that sit in strictly lower bins than each positive bin.
        negatives_below = np.cumsum(negative_histogram) - negative_histogram
        satisfied_pair = np.sum(
            positive_histogram * (negatives_below + 0.5 * negative_histogram))
        auc.append(satisfied_pair / total_case)

    return np.mean(auc)
-
-
def val(net, data_dir, filename, num_consumer=4, batch=32):
    """
    Validation function, estimate the performance of trained model

    Input:
        net: the trained neural network
        data_dir: path to the validation dataset
        filename: name of the validation dataset
        num_consumer: split number of validation dataset
        batch: validation batch size
    Outputs
        Float, AUC
    """
    # Forward the caller's batch size instead of a hard-coded 32.
    # NOTE(review): assumes the third positional argument of create_dataset
    # is the batch size - confirm against src/dataset.
    dataset = create_dataset(data_dir, filename, batch, ['feature', 'label'],
                             num_consumer)
    res_pred = []
    res_true = []
    for data, label in dataset.create_tuple_iterator():
        # The network is fed float32 tensors.
        x = net(Tensor(data, dtype=mstype.float32))
        res_pred.append(x.asnumpy())
        res_true.append(label.asnumpy())
    # Stack the per-batch results along the sample axis before scoring.
    res_pred = np.concatenate(res_pred, axis=0)
    res_true = np.concatenate(res_true, axis=0)
    auc = calculate_auc(res_true, res_pred)
    return auc
-
-
def validation(net, model_path, data_dir, filename, num_consumer, batch):
    """
    Load a checkpoint into the network and evaluate it.

    Input:
        net: the neural network that receives the checkpoint parameters
        model_path: path of the checkpoint file to load
        data_dir: path to the validation dataset
        filename: name of the validation dataset
        num_consumer: forwarded unchanged to val
        batch: forwarded unchanged to val
    Outputs
        The value returned by val for the loaded network
    """
    # Restore the saved parameters before running the evaluation.
    load_param_into_net(net, load_checkpoint(model_path))
    return val(net, data_dir, filename, num_consumer, batch)
-
-
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate model')
    parser.add_argument('--device_id', type=int, help='device ID', default=None)
    args = parser.parse_args()

    # A device ID given on the command line takes precedence over the
    # configured one.
    device_id = cfg.device_id if args.device_id is None else args.device_id
    context.set_context(device_target=cfg.device_target,
                        mode=context.GRAPH_MODE,
                        device_id=device_id)

    network = MusicTaggerCNN(in_classes=[1, 128, 384, 768, 2048],
                             kernel_size=[3, 3, 3, 3, 3],
                             padding=[0] * 5,
                             maxpool=[(2, 4), (4, 5), (3, 8), (4, 8)],
                             has_bias=True)
    # Evaluation only: switch the network out of training mode.
    network.set_train(False)

    checkpoint_file = cfg.checkpoint_path + "/" + cfg.model_name
    auc_val = validation(network, checkpoint_file, cfg.data_dir,
                         cfg.val_filename, cfg.num_consumer, cfg.batch_size)

    banner = "=" * 10
    print(banner + "Validation Performance" + banner)
    print("AUC: {:.5f}".format(auc_val))
|