From: @zhao_ting_v Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34pull/14528/MERGE
| @@ -155,6 +155,7 @@ python eval.py --net=[resnet50|resnet101] --dataset=[cifar10|imagenet2012] --dat | |||
| ├── src | |||
| ├── config.py # parameter configuration | |||
| ├── dataset.py # data preprocessing | |||
| ├─ eval_callback.py # evaluation callback while training | |||
| ├── CrossEntropySmooth.py # loss definition for ImageNet2012 dataset | |||
| ├── lr_generator.py # generate learning rate for each step | |||
| ├── resnet.py # resnet backbone, including resnet50 and resnet101 and se-resnet50 | |||
| @@ -323,6 +324,10 @@ bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagen | |||
| bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) | |||
| ``` | |||
| #### Evaluation while training | |||
| If you want to evaluate while training, add `run_eval` to the start shell script and set it to True. You can also set the argument options `eval_dataset_path`, `save_best_ckpt`, `eval_start_epoch` and `eval_interval` when `run_eval` is True. | |||
| ### Result | |||
| - Training ResNet18 with CIFAR-10 dataset | |||
| @@ -143,7 +143,8 @@ bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] | |||
| ├── src | |||
| ├── config.py # 参数配置 | |||
| ├── dataset.py # 数据预处理 | |||
| ├── CrossEntropySmooth.py # ImageNet2012数据集的损失定义 | |||
| ├─ eval_callback.py # 训练时推理回调函数 | |||
| ├── CrossEntropySmooth.py # ImageNet2012数据集的损失定义 | |||
| ├── lr_generator.py # 生成每个步骤的学习率 | |||
| └── resnet.py # ResNet骨干网络,包括ResNet50、ResNet101和SE-ResNet50 | |||
| ├── eval.py # 评估网络 | |||
| @@ -297,6 +298,10 @@ bash run_parameter_server_train.sh [resnet18|resnet50|resnet101] [cifar10|imagen | |||
| bash run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选) | |||
| ``` | |||
| #### 训练时推理 | |||
| 训练时推理需要在启动文件中添加`run_eval` 并设置为True。与此同时需要设置: `eval_dataset_path`, `save_best_ckpt`, `eval_start_epoch`, `eval_interval` 。 | |||
| ### 结果 | |||
| - 使用CIFAR-10数据集训练ResNet18 | |||
| @@ -0,0 +1,90 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation callback when training""" | |||
| import os | |||
| import stat | |||
| from mindspore import save_checkpoint | |||
| from mindspore import log as logger | |||
| from mindspore.train.callback import Callback | |||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    Runs the supplied evaluation function on a schedule and optionally keeps
    the checkpoint that achieved the best metric so far.

    Args:
        eval_function (function): evaluation function; receives ``eval_param_dict``
            and returns a scalar metric value.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save the best checkpoint, default is True.
        ckpt_directory (str): directory for the best checkpoint, default is ``./``.
        besk_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
            (Misspelled parameter name kept for backward compatibility with callers.)
        metrics_name (str): evaluation metrics name, default is `acc`.

    Returns:
        None

    Raises:
        ValueError: If `interval` is less than 1.

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(ckpt_directory, exist_ok=True)
        # Attribute name kept as-is (typo included) for backward compatibility.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            # Clear read-only bit first so removal works on Windows too.
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Evaluate at the end of each scheduled epoch and keep the best checkpoint."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            res = self.eval_function(self.eval_param_dict)
            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                print("update best result: {}".format(res), flush=True)
                if self.save_best_ckpt:
                    # Replace the previous best checkpoint, if any, with the new one.
                    if os.path.exists(self.bast_ckpt_path):
                        self.remove_ckpoint_file(self.bast_ckpt_path)
                    save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
                    print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Report the best metric and its epoch when training finishes."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
| @@ -0,0 +1,132 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """evaluation metric.""" | |||
| from mindspore.communication.management import GlobalComm | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
class ClassifyCorrectCell(nn.Cell):
    r"""
    Cell that returns correct count of the prediction in classification network.

    This Cell accepts a network as arguments.
    It returns correct count of the prediction to calculate the metrics.

    Args:
        network (Cell): The network Cell.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple, containing a scalar correct count of the prediction

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> eval_net = nn.ClassifyCorrectCell(net)
    """

    def __init__(self, network):
        super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
        self._network = network
        self.argmax = P.Argmax()
        self.equal = P.Equal()
        self.cast = P.Cast()
        self.reduce_sum = P.ReduceSum()
        # Sum per-device correct counts across every device in the world group.
        self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)

    def construct(self, data, label):
        outputs = self._network(data)
        y_pred = self.argmax(outputs)
        y_pred = self.cast(y_pred, mstype.int32)
        y_correct = self.equal(y_pred, label)
        y_correct = self.cast(y_correct, mstype.float32)
        y_correct = self.reduce_sum(y_correct)
        total_correct = self.allreduce(y_correct)
        return (total_correct,)
class DistAccuracy(nn.Metric):
    r"""
    Calculates the accuracy for classification data in distributed mode.

    The accuracy class creates two local variables, correct number and total number that are used to compute the
    frequency with which predictions matches labels. This frequency is ultimately returned as the accuracy: an
    idempotent operation that simply divides correct number by total number.

    .. math::
        \text{accuracy} =\frac{\text{true_positive} + \text{true_negative}}
        {\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}

    Args:
        batch_size (int): evaluation batch size on a single device.
        device_num (int): number of devices participating in the evaluation.

    Examples:
        >>> y_correct = Tensor(np.array([20]))
        >>> metric = DistAccuracy(batch_size=3, device_num=8)
        >>> metric.clear()
        >>> metric.update(y_correct)
        >>> accuracy = metric.eval()
    """

    def __init__(self, batch_size, device_num):
        super(DistAccuracy, self).__init__()
        self.clear()
        self.batch_size = batch_size
        self.device_num = device_num

    def clear(self):
        """Clears the internal evaluation result."""
        self._correct_num = 0
        self._total_num = 0

    def update(self, *inputs):
        """
        Updates the internal evaluation result.

        Args:
            inputs: Input `y_correct`. `y_correct` is a `scalar Tensor`.
                `y_correct` is the right prediction count that gathered from all devices
                it's a scalar in float type

        Raises:
            ValueError: If the number of the input is not 1.
        """
        if len(inputs) != 1:
            raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(len(inputs)))
        y_correct = self._convert_data(inputs[0])
        self._correct_num += y_correct
        # Each update accounts for one batch from every participating device.
        self._total_num += self.batch_size * self.device_num

    def eval(self):
        """
        Computes the accuracy.

        Returns:
            Float, the computed result.

        Raises:
            RuntimeError: If the sample size is 0.
        """
        if self._total_num == 0:
            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
        return self._correct_num / self._total_num
| @@ -31,9 +31,12 @@ from mindspore.common import set_seed | |||
| from mindspore.parallel import set_algo_parameters | |||
| import mindspore.nn as nn | |||
| import mindspore.common.initializer as weight_init | |||
| import mindspore.log as logger | |||
| from src.lr_generator import get_lr, warmup_cosine_annealing_lr | |||
| from src.CrossEntropySmooth import CrossEntropySmooth | |||
| from src.config import cfg | |||
| from src.eval_callback import EvalCallBack | |||
| from src.metric import DistAccuracy, ClassifyCorrectCell | |||
| parser = argparse.ArgumentParser(description='Image classification') | |||
| parser.add_argument('--net', type=str, default=None, help='Resnet Model, resnet18, resnet50 or resnet101') | |||
| @@ -48,6 +51,15 @@ parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained ch | |||
| parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train') | |||
| parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, | |||
| help="Filter head weight parameters, default is False.") | |||
| parser.add_argument("--run_eval", type=ast.literal_eval, default=False, | |||
| help="Run evaluation when training, default is False.") | |||
| parser.add_argument('--eval_dataset_path', type=str, default=None, help='Evaluation dataset path when run_eval is True') | |||
| parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True, | |||
| help="Save best checkpoint when run_eval is True, default is True.") | |||
| parser.add_argument("--eval_start_epoch", type=int, default=40, | |||
| help="Evaluation start epoch when run_eval is True, default is 40.") | |||
| parser.add_argument("--eval_interval", type=int, default=1, | |||
| help="Evaluation interval when run_eval is True, default is 1.") | |||
| args_opt = parser.parse_args() | |||
| set_seed(1) | |||
| @@ -89,6 +101,12 @@ def filter_checkpoint_parameter_by_list(origin_dict, param_filter): | |||
| del origin_dict[key] | |||
| break | |||
def apply_eval(eval_param):
    """Run one evaluation pass and return the metric named in *eval_param*.

    *eval_param* carries three keys: "model" (a Model to evaluate),
    "dataset" (the evaluation dataset) and "metrics_name" (which entry of
    the metrics dict to return).
    """
    results = eval_param["model"].eval(eval_param["dataset"])
    return results[eval_param["metrics_name"]]
| if __name__ == '__main__': | |||
| target = args_opt.device_target | |||
| @@ -185,12 +203,16 @@ if __name__ == '__main__': | |||
| else: | |||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||
| loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, | |||
| amp_level="O2", keep_batchnorm_fp32=False) | |||
| dist_eval_network = ClassifyCorrectCell(net) if args_opt.run_distribute else None | |||
| metrics = {"acc"} | |||
| if args_opt.run_distribute: | |||
| metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=args_opt.device_num)} | |||
| model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics, | |||
| amp_level="O2", keep_batchnorm_fp32=False, eval_network=dist_eval_network) | |||
| if (args_opt.net != "resnet101" and args_opt.net != "resnet50") or \ | |||
| args_opt.parameter_server or target == "CPU": | |||
| ## fp32 training | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network) | |||
| if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012": | |||
| from src.lr_generator import get_thor_damping | |||
| damping = get_thor_damping(0, config.damping_init, config.damping_decay, 70, step_size) | |||
| @@ -201,6 +223,8 @@ if __name__ == '__main__': | |||
| loss_scale_manager=loss_scale, metrics={'acc'}, | |||
| amp_level="O2", keep_batchnorm_fp32=False, | |||
| frequency=config.frequency) | |||
| args_opt.run_eval = False | |||
| logger.warning("Thor optimizer not support evaluation while training.") | |||
| # define callbacks | |||
| time_cb = TimeMonitor(data_size=step_size) | |||
| @@ -211,7 +235,17 @@ if __name__ == '__main__': | |||
| keep_checkpoint_max=config.keep_checkpoint_max) | |||
| ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) | |||
| cb += [ckpt_cb] | |||
| if args_opt.run_eval: | |||
| if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)): | |||
| raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path)) | |||
| eval_dataset = create_dataset(dataset_path=args_opt.eval_dataset_path, do_train=False, | |||
| batch_size=config.batch_size, target=target) | |||
| eval_param_dict = {"model": model, "dataset": eval_dataset, "metrics_name": "acc"} | |||
| eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval, | |||
| eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True, | |||
| ckpt_directory=ckpt_save_dir, besk_ckpt_name="best_acc.ckpt", | |||
| metrics_name="acc") | |||
| cb += [eval_cb] | |||
| # train model | |||
| if args_opt.net == "se-resnet50": | |||
| config.epoch_size = config.train_epoch_size | |||
| @@ -123,8 +123,8 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>) | |||
| ### Prepare the model | |||
| 1. Chose the model by changing the `using_model` in `src/confgi.py`. The optional models are: `ssd300`, `ssd_mobilenet_v1_fpn`. | |||
| 2. Change the dataset config in the corresponding config. `src/config_ssd300.py` or `src/config_ssd_mobilenet_v1_fpn.py`. | |||
| 1. Choose the model by changing the `using_model` in `src/config.py`. The optional models are: `ssd300`, `ssd_mobilenet_v1_fpn`, `ssd_resnet50_fpn`, `ssd_vgg16`. | |||
| 2. Change the dataset config in the corresponding config. `src/config_ssd300.py`, `src/config_ssd_mobilenet_v1_fpn.py`, `src/config_ssd_resnet50_fpn.py`, `src/config_ssd_vgg16.py`. | |||
| 3. If you are running with `ssd_mobilenet_v1_fpn`, you need a pretrained model for `mobilenet_v1`. Set the checkpoint path to `feature_extractor_base_param` in `src/config_ssd_mobilenet_v1_fpn.py`. For more detail about training mobilnet_v1, please refer to the mobilenetv1 model. | |||
| ### Run the scripts | |||
| @@ -201,6 +201,7 @@ Then you can run everything just like on ascend. | |||
| ├─ src | |||
| ├─ __init__.py # init file | |||
| ├─ box_utils.py # bbox utils | |||
| ├─ eval_callback.py # evaluation callback when training | |||
| ├─ eval_utils.py # metrics utils | |||
| ├─ config.py # total config | |||
| ├─ dataset.py # create dataset and process dataset | |||
| @@ -229,6 +230,10 @@ Then you can run everything just like on ascend. | |||
| "loss_scale": 1024 # Loss scale | |||
| "filter_weight": False # Load parameters in head layer or not. If the class numbers of train dataset is different from the class numbers in pre_trained checkpoint, please set True. | |||
| "freeze_layer": "none" # Freeze the backbone parameters or not, support none and backbone. | |||
| "run_eval": False # Run evaluation when training | |||
| "save_best_ckpt": True # Save best checkpoint when run_eval is True | |||
| "eval_start_epoch": 40 # Evaluation start epoch when run_eval is True | |||
| "eval_interval": 1 # valuation interval when run_eval is True | |||
| "class_num": 81 # Dataset class number | |||
| "image_shape": [300, 300] # Image height and width used as input to the model | |||
| @@ -311,6 +316,10 @@ epoch time: 150753.701, per step time: 329.157 | |||
| ... | |||
| ``` | |||
| #### Evaluation while training | |||
| If you want to evaluate while training, add `run_eval` to the start shell script and set it to True. You can also set the argument options `save_best_ckpt`, `eval_start_epoch` and `eval_interval` when `run_eval` is True. | |||
| #### Transfer Training | |||
| You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps. | |||
| @@ -17,14 +17,12 @@ | |||
| import os | |||
| import argparse | |||
| import time | |||
| import numpy as np | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.ssd import SSD300, SsdInferWithDecoder, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 | |||
| from src.dataset import create_ssd_dataset, create_mindrecord | |||
| from src.config import config | |||
| from src.eval_utils import metrics | |||
| from src.eval_utils import apply_eval | |||
| from src.box_utils import default_boxes | |||
| def ssd_eval(dataset_path, ckpt_path, anno_json): | |||
| @@ -50,31 +48,12 @@ def ssd_eval(dataset_path, ckpt_path, anno_json): | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| i = batch_size | |||
| total = ds.get_dataset_size() * batch_size | |||
| start = time.time() | |||
| pred_data = [] | |||
| print("\n========================================\n") | |||
| print("total images num: ", total) | |||
| print("Processing, please wait a moment.") | |||
| for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1): | |||
| img_id = data['img_id'] | |||
| img_np = data['image'] | |||
| image_shape = data['image_shape'] | |||
| output = net(Tensor(img_np)) | |||
| for batch_idx in range(img_np.shape[0]): | |||
| pred_data.append({"boxes": output[0].asnumpy()[batch_idx], | |||
| "box_scores": output[1].asnumpy()[batch_idx], | |||
| "img_id": int(np.squeeze(img_id[batch_idx])), | |||
| "image_shape": image_shape[batch_idx]}) | |||
| percent = round(i / total * 100., 2) | |||
| print(f' {str(percent)} [{i}/{total}]', end='\r') | |||
| i += batch_size | |||
| cost_time = int((time.time() - start) * 1000) | |||
| print(f' 100% [{total}/{total}] cost {cost_time} ms') | |||
| mAP = metrics(pred_data, anno_json) | |||
| eval_param_dict = {"net": net, "dataset": ds, "anno_json": anno_json} | |||
| mAP = apply_eval(eval_param_dict) | |||
| print("\n========================================\n") | |||
| print(f"mAP: {mAP}") | |||
| @@ -0,0 +1,90 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation callback when training""" | |||
| import os | |||
| import stat | |||
| from mindspore import save_checkpoint | |||
| from mindspore import log as logger | |||
| from mindspore.train.callback import Callback | |||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    Periodically runs an evaluation function and, optionally, keeps the
    checkpoint with the best metric value seen so far.

    Args:
        eval_function (function): evaluation function.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name, default is `acc`.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.eval_function = eval_function
        self.eval_param_dict = eval_param_dict
        self.eval_start_epoch = eval_start_epoch
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # Not yet due for evaluation this epoch: nothing to do.
        if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            return
        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        if res < self.best_res:
            return
        self.best_res = res
        self.best_epoch = cur_epoch
        print("update best result: {}".format(res), flush=True)
        if self.save_best_ckpt:
            if os.path.exists(self.bast_ckpt_path):
                self.remove_ckpoint_file(self.bast_ckpt_path)
            save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
            print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: report the best metric and epoch."""
        summary = "End training, the best {0} is: {1}, the best {0} epoch is {2}"
        print(summary.format(self.metrics_name, self.best_res, self.best_epoch), flush=True)
| @@ -16,8 +16,28 @@ | |||
| import json | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from .config import config | |||
def apply_eval(eval_param_dict):
    """Run detection evaluation described by *eval_param_dict* and return the mAP.

    Expects keys "net" (inference network), "dataset" (evaluation dataset)
    and "anno_json" (path to the COCO-style annotation file).
    """
    net = eval_param_dict["net"]
    net.set_train(False)
    dataset = eval_param_dict["dataset"]
    anno_json = eval_param_dict["anno_json"]
    predictions = []
    for sample in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
        image_batch = sample['image']
        id_batch = sample['img_id']
        shape_batch = sample['image_shape']
        output = net(Tensor(image_batch))
        for idx in range(image_batch.shape[0]):
            predictions.append({"boxes": output[0].asnumpy()[idx],
                                "box_scores": output[1].asnumpy()[idx],
                                "img_id": int(np.squeeze(id_batch[idx])),
                                "image_shape": shape_batch[idx]})
    return metrics(predictions, anno_json)
| def apply_nms(all_boxes, all_scores, thres, max_boxes): | |||
| """Apply NMS to bboxes.""" | |||
| @@ -15,6 +15,7 @@ | |||
| """Train SSD and get checkpoint files.""" | |||
| import os | |||
| import argparse | |||
| import ast | |||
| import mindspore.nn as nn | |||
| @@ -25,11 +26,15 @@ from mindspore.train import Model | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.common import set_seed, dtype | |||
| from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 | |||
| from src.ssd import SSD300, SsdInferWithDecoder, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2,\ | |||
| ssd_mobilenet_v1_fpn, ssd_resnet50_fpn, ssd_vgg16 | |||
| from src.config import config | |||
| from src.dataset import create_ssd_dataset, create_mindrecord | |||
| from src.lr_schedule import get_lr | |||
| from src.init_params import init_net_param, filter_checkpoint_parameter_by_list | |||
| from src.eval_callback import EvalCallBack | |||
| from src.eval_utils import apply_eval | |||
| from src.box_utils import default_boxes | |||
| set_seed(1) | |||
| @@ -57,6 +62,14 @@ def get_args(): | |||
| parser.add_argument('--freeze_layer', type=str, default="none", choices=["none", "backbone"], | |||
| help="freeze the weights of network, support freeze the backbone's weights, " | |||
| "default is not freezing.") | |||
| parser.add_argument("--run_eval", type=ast.literal_eval, default=False, | |||
| help="Run evaluation when training, default is False.") | |||
| parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True, | |||
| help="Save best checkpoint when run_eval is True, default is True.") | |||
| parser.add_argument("--eval_start_epoch", type=int, default=40, | |||
| help="Evaluation start epoch when run_eval is True, default is 40.") | |||
| parser.add_argument("--eval_interval", type=int, default=1, | |||
| help="Evaluation interval when run_eval is True, default is 1.") | |||
| args_opt = parser.parse_args() | |||
| return args_opt | |||
| @@ -170,8 +183,25 @@ def main(): | |||
| config.momentum, config.weight_decay, loss_scale) | |||
| net = TrainingWrapper(net, opt, loss_scale) | |||
| callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] | |||
| if args_opt.run_eval: | |||
| eval_net = SsdInferWithDecoder(ssd, Tensor(default_boxes), config) | |||
| eval_net.set_train(False) | |||
| mindrecord_file = create_mindrecord(args_opt.dataset, "ssd_eval.mindrecord", False) | |||
| eval_dataset = create_ssd_dataset(mindrecord_file, batch_size=args_opt.batch_size, repeat_num=1, | |||
| is_training=False, use_multiprocessing=False) | |||
| if args_opt.dataset == "coco": | |||
| anno_json = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type)) | |||
| elif args_opt.dataset == "voc": | |||
| anno_json = os.path.join(config.voc_root, config.voc_json) | |||
| else: | |||
| raise ValueError('SSD eval only support dataset mode is coco and voc!') | |||
| eval_param_dict = {"net": eval_net, "dataset": eval_dataset, "anno_json": anno_json} | |||
| eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval, | |||
| eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True, | |||
| ckpt_directory=save_ckpt_path, besk_ckpt_name="best_map.ckpt", | |||
| metrics_name="mAP") | |||
| callback.append(eval_cb) | |||
| model = Model(net) | |||
| dataset_sink_mode = False | |||
| if args_opt.mode == "sink" and args_opt.run_platform != "CPU": | |||
| @@ -128,6 +128,7 @@ Then you can run everything just like on ascend. | |||
| │ ├──config.py // parameter configuration | |||
| │ ├──data_loader.py // creating dataset | |||
| │ ├──loss.py // loss | |||
| │ ├──eval_callback.py // evaluation callback while training | |||
| │ ├──utils.py // General components (callback function) | |||
| │ ├──unet_medical // Unet medical architecture | |||
| ├──__init__.py // init file | |||
| @@ -168,6 +169,11 @@ Parameters for both training and evaluation can be set in config.py | |||
| 'resume_ckpt': './', # pretrain model path | |||
| 'transfer_training': False # whether do transfer training | |||
| 'filter_weight': ["final.weight"] # weight name to filter while doing transfer training | |||
| 'run_eval': False # Run evaluation when training | |||
| 'save_best_ckpt': True # Save best checkpoint when run_eval is True | |||
| 'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True | |||
| 'eval_interval': 1 # Evaluation interval when run_eval is True | |||
| ``` | |||
| - config for Unet++, cell nuclei dataset | |||
| @@ -193,6 +199,10 @@ Parameters for both training and evaluation can be set in config.py | |||
| 'resume_ckpt': './', # pretrain model path | |||
| 'transfer_training': False # whether do transfer training | |||
| 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] # weight name to filter while doing transfer training | |||
| 'run_eval': False # Run evaluation when training | |||
| 'save_best_ckpt': True # Save best checkpoint when run_eval is True | |||
| 'eval_start_epoch': 0 # Evaluation start epoch when run_eval is True | |||
| 'eval_interval': 1 # Evaluation interval when run_eval is True | |||
| ``` | |||
| ## [Training Process](#contents) | |||
| @@ -245,6 +255,10 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891 | |||
| step: 300, loss is 0.18949677, fps is 57.63118508760329 | |||
| ``` | |||
| #### Evaluation while training | |||
| If you want to evaluate while training, add `run_eval` to the start shell script and set it to True. You can also set the argument options `save_best_ckpt`, `eval_start_epoch`, `eval_interval` and `eval_metrics` when `run_eval` is True. | |||
| ## [Evaluation Process](#contents) | |||
| ### Evaluation | |||
| @@ -132,6 +132,7 @@ bash scripts/docker_start.sh unet:20.1.0 [DATA_DIR] [MODEL_DIR] | |||
| │ ├──config.py // 参数配置 | |||
| │ ├──data_loader.py // 数据处理 | |||
| │ ├──loss.py // 损失函数 | |||
| │ ├─ eval_callback.py // 训练时推理回调函数 | |||
| │ ├──utils.py // 通用组件(回调函数) | |||
| │ ├──unet_medical // 医学图像处理Unet结构 | |||
| ├──__init__.py | |||
| @@ -247,6 +248,10 @@ step: 299, loss is 0.20551169, fps is 58.4039329983891 | |||
| step: 300, loss is 0.18949677, fps is 57.63118508760329 | |||
| ``` | |||
| #### 训练时推理 | |||
| 训练时推理需要在启动文件中添加`run_eval` 并设置为True。与此同时需要设置: `save_best_ckpt`, `eval_start_epoch`, `eval_interval`, `eval_metrics` 。 | |||
| ## 评估过程 | |||
| ### 评估 | |||
| @@ -16,10 +16,6 @@ | |||
| import os | |||
| import argparse | |||
| import logging | |||
| import cv2 | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as F | |||
| from mindspore import context, Model | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| @@ -27,76 +23,11 @@ from src.data_loader import create_dataset, create_cell_nuclei_dataset | |||
| from src.unet_medical import UNetMedical | |||
| from src.unet_nested import NestedUNet, UNet | |||
| from src.config import cfg_unet | |||
| from src.utils import UnetEval | |||
| from src.utils import UnetEval, TempLoss, dice_coeff | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) | |||
| class TempLoss(nn.Cell): | |||
| """A temp loss cell.""" | |||
| def __init__(self): | |||
| super(TempLoss, self).__init__() | |||
| self.identity = F.identity() | |||
| def construct(self, logits, label): | |||
| return self.identity(logits) | |||
| class dice_coeff(nn.Metric): | |||
| def __init__(self): | |||
| super(dice_coeff, self).__init__() | |||
| self.clear() | |||
| def clear(self): | |||
| self._dice_coeff_sum = 0 | |||
| self._iou_sum = 0 | |||
| self._samples_num = 0 | |||
| def update(self, *inputs): | |||
| if len(inputs) != 2: | |||
| raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs))) | |||
| y = self._convert_data(inputs[1]) | |||
| self._samples_num += y.shape[0] | |||
| y = y.transpose(0, 2, 3, 1) | |||
| b, h, w, c = y.shape | |||
| if b != 1: | |||
| raise ValueError('Batch size should be 1 when in evaluation.') | |||
| y = y.reshape((h, w, c)) | |||
| if cfg_unet["eval_activate"].lower() == "softmax": | |||
| y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0) | |||
| if cfg_unet["eval_resize"]: | |||
| y_pred = [] | |||
| for i in range(cfg_unet["num_classes"]): | |||
| y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255) | |||
| y_pred = np.stack(y_pred, axis=-1) | |||
| else: | |||
| y_pred = y_softmax | |||
| elif cfg_unet["eval_activate"].lower() == "argmax": | |||
| y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0) | |||
| y_pred = [] | |||
| for i in range(cfg_unet["num_classes"]): | |||
| if cfg_unet["eval_resize"]: | |||
| y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST)) | |||
| else: | |||
| y_pred.append(np.float32(y_argmax == i)) | |||
| y_pred = np.stack(y_pred, axis=-1) | |||
| else: | |||
| raise ValueError('config eval_activate should be softmax or argmax.') | |||
| y_pred = y_pred.astype(np.float32) | |||
| inter = np.dot(y_pred.flatten(), y.flatten()) | |||
| union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) | |||
| single_dice_coeff = 2*float(inter)/float(union+1e-6) | |||
| single_iou = single_dice_coeff / (2 - single_dice_coeff) | |||
| print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou)) | |||
| self._dice_coeff_sum += single_dice_coeff | |||
| self._iou_sum += single_iou | |||
| def eval(self): | |||
| if self._samples_num == 0: | |||
| raise RuntimeError('Total samples num must not be 0.') | |||
| return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num)) | |||
| def test_net(data_dir, | |||
| ckpt_path, | |||
| cross_valid_ind=1, | |||
| @@ -119,7 +50,7 @@ def test_net(data_dir, | |||
| else: | |||
| _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, | |||
| do_crop=cfg['crop'], img_size=cfg['img_size']) | |||
| model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()}) | |||
| model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff(cfg_unet)}) | |||
| print("============== Starting Evaluating ============") | |||
| eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"] | |||
| @@ -0,0 +1,90 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation callback when training""" | |||
| import os | |||
| import stat | |||
| from mindspore import save_checkpoint | |||
| from mindspore import log as logger | |||
| from mindspore.train.callback import Callback | |||
class EvalCallBack(Callback):
    """Callback that runs evaluation periodically while training.

    Every ``interval`` epochs (starting at ``eval_start_epoch``) the supplied
    ``eval_function`` is invoked; whenever it returns a new best result, the
    current network is optionally saved as the "best" checkpoint.

    Args:
        eval_function (function): evaluation function; receives
            ``eval_param_dict`` and returns a scalar result.
        eval_param_dict (dict): arguments forwarded to ``eval_function``.
        interval (int): number of epochs between evaluations, default is 1.
        eval_start_epoch (int): first epoch at which evaluation may run, default is 1.
        save_best_ckpt (bool): whether to save the best checkpoint, default is True.
        ckpt_directory (str): directory for the best checkpoint, default is "./".
        besk_ckpt_name (str): best checkpoint file name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name used in log lines, default is `acc`.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.eval_function = eval_function
        self.eval_param_dict = eval_param_dict
        self.interval = interval
        self.eval_start_epoch = eval_start_epoch
        self.save_best_ckpt = save_best_ckpt
        self.metrics_name = metrics_name
        # Track the best result seen so far and the epoch that produced it.
        self.best_res = 0
        self.best_epoch = 0
        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            # Make the file writable first in case a previous save left it read-only.
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end: evaluate and keep the best checkpoint."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # Only evaluate on schedule: at or after the start epoch, every `interval` epochs.
        if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            return
        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        if res < self.best_res:
            return
        self.best_res = res
        self.best_epoch = cur_epoch
        print("update best result: {}".format(res), flush=True)
        if self.save_best_ckpt:
            # Replace any stale best checkpoint before writing the new one.
            if os.path.exists(self.bast_ckpt_path):
                self.remove_ckpoint_file(self.bast_ckpt_path)
            save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
            print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback at the end of training: report the best result and its epoch."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
| @@ -41,7 +41,7 @@ class MultiCrossEntropyWithLogits(nn.Cell): | |||
| def __init__(self): | |||
| super(MultiCrossEntropyWithLogits, self).__init__() | |||
| self.loss = CrossEntropyWithLogits() | |||
| self.squeeze = F.Squeeze() | |||
| self.squeeze = F.Squeeze(axis=0) | |||
| def construct(self, logits, label): | |||
| total_loss = 0 | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| import time | |||
| import cv2 | |||
| import numpy as np | |||
| from PIL import Image | |||
| from mindspore import nn | |||
| @@ -25,20 +26,100 @@ class UnetEval(nn.Cell): | |||
| """ | |||
| Add Unet evaluation activation. | |||
| """ | |||
| def __init__(self, net): | |||
| def __init__(self, net, need_slice=False): | |||
| super(UnetEval, self).__init__() | |||
| self.net = net | |||
| self.need_slice = need_slice | |||
| self.transpose = ops.Transpose() | |||
| self.softmax = ops.Softmax(axis=-1) | |||
| self.argmax = ops.Argmax(axis=-1) | |||
| self.squeeze = ops.Squeeze(axis=0) | |||
| def construct(self, x): | |||
| out = self.net(x) | |||
| if self.need_slice: | |||
| out = self.squeeze(out[-1:]) | |||
| out = self.transpose(out, (0, 2, 3, 1)) | |||
| softmax_out = self.softmax(out) | |||
| argmax_out = self.argmax(out) | |||
| return (softmax_out, argmax_out) | |||
class TempLoss(nn.Cell):
    """A temp loss cell.

    Placeholder "loss" used when building an evaluation-only Model: it ignores
    the label and passes the network logits through unchanged so the metric
    receives the raw outputs.
    """
    def __init__(self):
        super(TempLoss, self).__init__()
        # NOTE(review): relies on `ops.identity()` being constructible with no
        # arguments in the installed MindSpore version — confirm.
        self.identity = ops.identity()
    def construct(self, logits, label):
        # Label is intentionally unused; the "loss" is just the logits.
        return self.identity(logits)
def apply_eval(eval_param_dict):
    """Run evaluation and return the requested scalar score.

    Args:
        eval_param_dict (dict): must contain ``model`` (a mindspore Model built
            with ``metrics={"dice_coeff": dice_coeff(...)}``), ``dataset`` (the
            validation dataset) and ``metrics_name`` ("dice_coeff" or "iou").

    Returns:
        float: the mean dice coefficient or mean IOU of the evaluation run.
    """
    model = eval_param_dict["model"]
    dataset = eval_param_dict["dataset"]
    metrics_name = eval_param_dict["metrics_name"]
    # The dice_coeff metric returns a (dice, iou) tuple; metrics_name selects
    # which element to report.
    index = 0 if metrics_name == "dice_coeff" else 1
    # Bug fix: the eval results dict is always keyed by the metric name the
    # Model was built with ("dice_coeff"), never by "iou" — indexing with
    # metrics_name raised KeyError when --eval_metrics iou was requested.
    eval_score = model.eval(dataset, dataset_sink_mode=False)["dice_coeff"][index]
    return eval_score
class dice_coeff(nn.Metric):
    """Unet Metric, return dice coefficient and IOU.

    Accumulates a per-sample Dice coefficient and IOU on every ``update`` call
    and returns their averages from ``eval``.  Behaviour is driven by the
    ``cfg_unet`` keys ``eval_activate`` ("softmax" or "argmax"),
    ``eval_resize`` and ``num_classes``.
    """
    def __init__(self, cfg_unet, print_res=True):
        # cfg_unet: configuration dict (see class docstring for the keys read).
        # print_res: when True, print each sample's dice/IOU as it is computed.
        super(dice_coeff, self).__init__()
        self.clear()
        self.cfg_unet = cfg_unet
        self.print_res = print_res

    def clear(self):
        """Reset all accumulated statistics."""
        self._dice_coeff_sum = 0
        self._iou_sum = 0
        self._samples_num = 0

    def update(self, *inputs):
        """Accumulate dice/IOU for one sample.

        Args:
            inputs: a 2-tuple ``((y_softmax, y_argmax), y)`` where the first
                element is the prediction pair produced by ``UnetEval`` and
                ``y`` is the ground-truth label, laid out as (N, C, H, W).

        Raises:
            ValueError: if the number of inputs is not 2, the batch size is
                not 1, or ``eval_activate`` is neither softmax nor argmax.
        """
        if len(inputs) != 2:
            raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
        y = self._convert_data(inputs[1])
        self._samples_num += y.shape[0]
        # (N, C, H, W) -> (N, H, W, C) so the channel axis matches the
        # per-class predictions stacked with axis=-1 below.
        y = y.transpose(0, 2, 3, 1)
        b, h, w, c = y.shape
        if b != 1:
            raise ValueError('Batch size should be 1 when in evaluation.')
        y = y.reshape((h, w, c))
        if self.cfg_unet["eval_activate"].lower() == "softmax":
            y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
            if self.cfg_unet["eval_resize"]:
                y_pred = []
                for i in range(self.cfg_unet["num_classes"]):
                    # cv2.resize works on uint8 here, so probabilities are
                    # quantized to 0..255 and rescaled back after resizing.
                    y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
                y_pred = np.stack(y_pred, axis=-1)
            else:
                y_pred = y_softmax
        elif self.cfg_unet["eval_activate"].lower() == "argmax":
            y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
            y_pred = []
            for i in range(self.cfg_unet["num_classes"]):
                if self.cfg_unet["eval_resize"]:
                    # Nearest-neighbour interpolation keeps the resized
                    # class mask binary.
                    y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
                else:
                    y_pred.append(np.float32(y_argmax == i))
            y_pred = np.stack(y_pred, axis=-1)
        else:
            raise ValueError('config eval_activate should be softmax or argmax.')
        y_pred = y_pred.astype(np.float32)
        # Soft Dice: 2*(pred . gt) / (|pred|^2 + |gt|^2); 1e-6 guards the division.
        inter = np.dot(y_pred.flatten(), y.flatten())
        union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
        single_dice_coeff = 2 * float(inter) / float(union+1e-6)
        # IOU derived from Dice via iou = dice / (2 - dice).
        single_iou = single_dice_coeff / (2 - single_dice_coeff)
        if self.print_res:
            print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
        self._dice_coeff_sum += single_dice_coeff
        self._iou_sum += single_iou

    def eval(self):
        """Return (mean dice coefficient, mean IOU) over all samples seen."""
        if self._samples_num == 0:
            raise RuntimeError('Total samples num must not be 0.')
        return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
| class StepLossTimeMonitor(Callback): | |||
| def __init__(self, batch_size, per_print_times=1): | |||
| @@ -30,23 +30,25 @@ from src.unet_medical import UNetMedical | |||
| from src.unet_nested import NestedUNet, UNet | |||
| from src.data_loader import create_dataset, create_cell_nuclei_dataset | |||
| from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits | |||
| from src.utils import StepLossTimeMonitor, filter_checkpoint_parameter_by_list | |||
| from src.utils import StepLossTimeMonitor, UnetEval, TempLoss, apply_eval, filter_checkpoint_parameter_by_list, dice_coeff | |||
| from src.config import cfg_unet | |||
| from src.eval_callback import EvalCallBack | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) | |||
| mindspore.set_seed(1) | |||
| def train_net(data_dir, | |||
| def train_net(args_opt, | |||
| cross_valid_ind=1, | |||
| epochs=400, | |||
| batch_size=16, | |||
| lr=0.0001, | |||
| run_distribute=False, | |||
| cfg=None): | |||
| rank = 0 | |||
| group_size = 1 | |||
| data_dir = args_opt.data_url | |||
| run_distribute = args_opt.run_distribute | |||
| if run_distribute: | |||
| init() | |||
| group_size = get_group_size() | |||
| @@ -55,12 +57,13 @@ def train_net(data_dir, | |||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, | |||
| device_num=group_size, | |||
| gradients_mean=False) | |||
| need_slice = False | |||
| if cfg['model'] == 'unet_medical': | |||
| net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | |||
| elif cfg['model'] == 'unet_nested': | |||
| net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'], | |||
| use_bn=cfg['use_bn'], use_ds=cfg['use_ds']) | |||
| need_slice = cfg['use_ds'] | |||
| elif cfg['model'] == 'unet_simple': | |||
| net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes']) | |||
| else: | |||
| @@ -83,12 +86,15 @@ def train_net(data_dir, | |||
| train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size, | |||
| is_train=True, augment=True, split=0.8, rank=rank, | |||
| group_size=group_size) | |||
| valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, | |||
| eval_resize=cfg["eval_resize"], split=0.8, | |||
| python_multiprocessing=False) | |||
| else: | |||
| repeat = epochs | |||
| dataset_sink_mode = False | |||
| per_print_times = 1 | |||
| train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, run_distribute, | |||
| cfg["crop"], cfg['img_size']) | |||
| train_dataset, valid_dataset = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, | |||
| run_distribute, cfg["crop"], cfg['img_size']) | |||
| train_data_size = train_dataset.get_dataset_size() | |||
| print("dataset length is:", train_data_size) | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size, | |||
| @@ -106,6 +112,15 @@ def train_net(data_dir, | |||
| print("============== Starting Training ==============") | |||
| callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb] | |||
| if args_opt.run_eval: | |||
| eval_model = Model(UnetEval(net, need_slice=need_slice), loss_fn=TempLoss(), | |||
| metrics={"dice_coeff": dice_coeff(cfg_unet, False)}) | |||
| eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": args_opt.eval_metrics} | |||
| eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval, | |||
| eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True, | |||
| ckpt_directory='./ckpt_{}/'.format(device_id), besk_ckpt_name="best.ckpt", | |||
| metrics_name=args_opt.eval_metrics) | |||
| callbacks.append(eval_cb) | |||
| model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode) | |||
| print("============== End Training ==============") | |||
| @@ -117,6 +132,17 @@ def get_args(): | |||
| help='data directory') | |||
| parser.add_argument('-t', '--run_distribute', type=ast.literal_eval, | |||
| default=False, help='Run distribute, default: false.') | |||
| parser.add_argument("--run_eval", type=ast.literal_eval, default=False, | |||
| help="Run evaluation when training, default is False.") | |||
| parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True, | |||
| help="Save best checkpoint when run_eval is True, default is True.") | |||
| parser.add_argument("--eval_start_epoch", type=int, default=0, | |||
| help="Evaluation start epoch when run_eval is True, default is 0.") | |||
| parser.add_argument("--eval_interval", type=int, default=1, | |||
| help="Evaluation interval when run_eval is True, default is 1.") | |||
| parser.add_argument("--eval_metrics", type=str, default="dice_coeff", choices=("dice_coeff", "iou"), | |||
| help="Evaluation metrics when run_eval is True, support [dice_coeff, iou], " | |||
| "default is dice_coeff.") | |||
| return parser.parse_args() | |||
| @@ -127,10 +153,9 @@ if __name__ == '__main__': | |||
| print("Training setting:", args) | |||
| epoch_size = cfg_unet['epochs'] if not args.run_distribute else cfg_unet['distribute_epochs'] | |||
| train_net(data_dir=args.data_url, | |||
| train_net(args_opt=args, | |||
| cross_valid_ind=cfg_unet['cross_valid_ind'], | |||
| epochs=epoch_size, | |||
| batch_size=cfg_unet['batchsize'], | |||
| lr=cfg_unet['lr'], | |||
| run_distribute=args.run_distribute, | |||
| cfg=cfg_unet) | |||