| @@ -4,7 +4,7 @@ | |||
| - [Model Architecture](#model-architecture) | |||
| - [Dataset](#dataset) | |||
| - [Environment Requirements](#environment-requirements) | |||
| - [Quick Start](#quick-start) | |||
| - [Quick Start](#quick-start) | |||
| - [Script Description](#script-description) | |||
| - [Script and Sample Code](#script-and-sample-code) | |||
| - [Script Parameters](#script-parameters) | |||
| @@ -17,97 +17,89 @@ | |||
| - [Performance](#performance) | |||
| - [Evaluation Performance](#evaluation-performance) | |||
| - [How to use](#how-to-use) | |||
| - [Inference](#inference) | |||
| - [Inference](#inference) | |||
| - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model) | |||
| - [Description of Random Situation](#description-of-random-situation) | |||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||
| # [Unet Description](#contents) | |||
| ## [Unet Description](#contents) | |||
| Unet Medical model for 2D image segmentation. This implementation is as described in the original paper [UNet: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597). Unet, in the 2015 ISBI cell tracking competition, many of the best are obtained. In this paper, a network model for medical image segmentation is proposed, and a data enhancement method is proposed to effectively use the annotation data to solve the problem of insufficient annotation data in the medical field. A U-shaped network structure is also used to extract the context and location information. | |||
| [Paper](https://arxiv.org/abs/1505.04597): Olaf Ronneberger, Philipp Fischer, Thomas Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation." * conditionally accepted at MICCAI 2015*. 2015. | |||
| [Paper](https://arxiv.org/abs/1505.04597): Olaf Ronneberger, Philipp Fischer, Thomas Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation." *conditionally accepted at MICCAI 2015*. 2015. | |||
| # [Model Architecture](#contents) | |||
| Specifically, the U network structure is proposed in UNET, which can better extract and fuse high-level features and obtain context information and spatial location information. The U network structure is composed of encoder and decoder. The encoder is composed of two 3x3 conv and a 2x2 max pooling iteration. The number of channels is doubled after each down sampling. The decoder is composed of a 2x2 deconv, concat layer and two 3x3 convolutions, and then outputs after a 1x1 convolution. | |||
| # [Dataset](#contents) | |||
| Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home) | |||
| Dataset used: [ISBI Challenge](http://brainiac2.mit.edu/isbi_challenge/home) | |||
| - Description: The training and test datasets are two stacks of 30 sections from a serial section Transmission Electron Microscopy (ssTEM) data set of the Drosophila first instar larva ventral nerve cord (VNC). The microcube measures 2 x 2 x 1.5 microns approx., with a resolution of 4x4x50 nm/pixel. | |||
| - License: You are free to use this data set for the purpose of generating or testing non-commercial image segmentation software. If any scientific publications derive from the usage of this data set, you must cite TrakEM2 and the following publication: Cardona A, Saalfeld S, Preibisch S, Schmid B, Cheng A, Pulokas J, Tomancak P, Hartenstein V. 2010. An Integrated Micro- and Macroarchitectural Analysis of the Drosophila Brain by Computer-Assisted Serial Section Electron Microscopy. PLoS Biol 8(10): e1000502. doi:10.1371/journal.pbio.1000502. | |||
| - Dataset size:22.5M, | |||
| - Train:15M, 30 images (Training data contains 2 multi-page TIF files, each containing 30 2D-images. train-volume.tif and train-labels.tif respectly contain data and label.) | |||
| - Val:(We randomly divde the training data into 5-fold and evaluate the model by across 5-fold cross-validation.) | |||
| - Test:7.5M, 30 images (Testing data contains 1 multi-page TIF files, each containing 30 2D-images. test-volume.tif respectly contain data.) | |||
| - Dataset size:22.5M, | |||
| - Train:15M, 30 images (Training data contains 2 multi-page TIF files, each containing 30 2D-images. train-volume.tif and train-labels.tif respectly contain data and label.) | |||
| - Val:(We randomly divide the training data into 5-fold and evaluate the model by across 5-fold cross-validation.) | |||
| - Test:7.5M, 30 images (Testing data contains 1 multi-page TIF files, each containing 30 2D-images. test-volume.tif respectly contain data.) | |||
| - Data format:binary files(TIF file) | |||
| - Note:Data will be processed in src/data_loader.py | |||
| - Note:Data will be processed in src/data_loader.py | |||
| # [Environment Requirements](#contents) | |||
| - Hardware(Ascend) | |||
| - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. | |||
| - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. | |||
| - Framework | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - For more information, please check the resources below: | |||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
| # [Quick Start](#contents) | |||
| After installing MindSpore via the official website, you can start training and evaluation as follows: | |||
| After installing MindSpore via the official website, you can start training and evaluation as follows: | |||
| - running on Ascend | |||
| ```python | |||
| # run training example | |||
| python train.py --data_url=/path/to/data/ > train.log 2>&1 & | |||
| python train.py --data_url=/path/to/data/ > train.log 2>&1 & | |||
| OR | |||
| bash scripts/run_standalone_train.sh [DATASET] | |||
| # run distributed training example | |||
| bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] | |||
| # run evaluation example | |||
| python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 & | |||
| python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 & | |||
| OR | |||
| bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] | |||
| ``` | |||
| # [Script Description](#contents) | |||
| ## [Script and Sample Code](#contents) | |||
| ``` | |||
| ```text | |||
| ├── model_zoo | |||
| ├── README.md // descriptions about all the models | |||
| ├── unet | |||
| ├── unet | |||
| ├── README.md // descriptions about Unet | |||
| ├── scripts | |||
| ├── scripts | |||
| │ ├──run_standalone_train.sh // shell script for distributed on Ascend | |||
| │ ├──run_standalone_eval.sh // shell script for evaluation on Ascend | |||
| ├── src | |||
| ├── src | |||
| │ ├──config.py // parameter configuration | |||
| │ ├──data_loader.py // creating dataset | |||
| │ ├──loss.py // loss | |||
| │ ├──loss.py // loss | |||
| │ ├──utils.py // General components (callback function) | |||
| │ ├──unet.py // Unet architecture | |||
| ├──__init__.py // init file | |||
| ├──unet_model.py // unet model | |||
| ├──unet_model.py // unet model | |||
| ├──unet_parts.py // unet part | |||
| ├── train.py // training script | |||
| ├──launch_8p.py // training 8P script | |||
| ├── eval.py // evaluation script | |||
| ├── train.py // training script | |||
| ├──launch_8p.py // training 8P script | |||
| ├── eval.py // evaluation script | |||
| ``` | |||
| ## [Script Parameters](#contents) | |||
| @@ -133,24 +125,24 @@ Parameters for both training and evaluation can be set in config.py | |||
| 'resume_ckpt': './', # pretrain model path | |||
| ``` | |||
| ## [Training Process](#contents) | |||
| ### Training | |||
| ### Training | |||
| - running on Ascend | |||
| ``` | |||
| python train.py --data_url=/path/to/data/ > train.log 2>&1 & | |||
| ```shell | |||
| python train.py --data_url=/path/to/data/ > train.log 2>&1 & | |||
| OR | |||
| bash scripts/run_standalone_train.sh [DATASET] | |||
| ``` | |||
| The python command above will run in the background, you can view the results through the file `train.log`. | |||
| After training, you'll get some checkpoint files under the script folder by default. The loss value will be achieved as follows: | |||
| ``` | |||
| ```shell | |||
| # grep "loss is " train.log | |||
| step: 1, loss is 0.7011719, fps is 0.25025035060906264 | |||
| step: 2, loss is 0.69433594, fps is 56.77693756377044 | |||
| @@ -163,19 +155,20 @@ Parameters for both training and evaluation can be set in config.py | |||
| step: 598, loss is 0.19958496, fps is 57.95493929352674 | |||
| step: 599, loss is 0.18371582, fps is 58.04039977720966 | |||
| step: 600, loss is 0.22070312, fps is 56.99692546024671 | |||
| ``` | |||
| The model checkpoint will be saved in the current directory. | |||
| The model checkpoint will be saved in the current directory. | |||
| ### Distributed Training | |||
| ``` | |||
| ```shell | |||
| bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [DATASET] | |||
| ``` | |||
| The above shell script will run distribute training in the background. You can view the results through the file `logs/device[X]/log.log`. The loss value will be achieved as follows: | |||
| ``` | |||
| ```shell | |||
| # grep "loss is" logs/device0/log.log | |||
| step: 1, loss is 0.70524895, fps is 0.15914689861221412 | |||
| step: 2, loss is 0.6925452, fps is 56.43668656967454 | |||
| @@ -191,27 +184,27 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329 | |||
| - evaluation on ISBI dataset when running on Ascend | |||
| Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/unet/ckpt_unet_medical_adam-48_600.ckpt". | |||
| ``` | |||
| python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 & | |||
| ```shell | |||
| python eval.py --data_url=/path/to/data/ --ckpt_path=/path/to/checkpoint/ > eval.log 2>&1 & | |||
| OR | |||
| bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT] | |||
| ``` | |||
| The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows: | |||
| ``` | |||
| ```shell | |||
| # grep "Cross valid dice coeff is:" eval.log | |||
| ============== Cross valid dice coeff is: {'dice_coeff': 0.9085704886070473} | |||
| ``` | |||
| ``` | |||
| # [Model Description](#contents) | |||
| ## [Performance](#contents) | |||
| ### Evaluation Performance | |||
| ## Performance | |||
| ### Evaluation Performance | |||
| | Parameters | Ascend | | |||
| | -------------------------- | ------------------------------------------------------------ | | |||
| @@ -227,45 +220,74 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329 | |||
| | outputs | probability | | |||
| | Loss | 0.22070312 | | |||
| | Speed | 1pc: 267 ms/step; 8pc: 280 ms/step; | | |||
| | Total time | 1pc: 2.67 mins; 8pc: 1.40 mins | | |||
| | Total time | 1pc: 2.67 mins; 8pc: 1.40 mins | | |||
| | Parameters (M) | 93M | | |||
| | Checkpoint for Fine tuning | 355.11M (.ckpt file) | | |||
| | Scripts | [unet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/unet) | | |||
| ## [How to use](#contents) | |||
| ### Inference | |||
| If you need to use the trained model to perform inference on multiple hardware platforms, such as Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). Following the steps below, this is a simple example: | |||
| - Running on Ascend | |||
| ``` | |||
| ```python | |||
| # Set context | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",save_graphs=True,device_id=device_id) | |||
| # Load unseen dataset for inference | |||
| _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False) | |||
| # Define model and Load pre-trained model | |||
| net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | |||
| param_dict= load_checkpoint(ckpt_path) | |||
| load_param_into_net(net , param_dict) | |||
| criterion = CrossEntropyWithLogits() | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| # Make predictions on the unseen dataset | |||
| print("============== Starting Evaluating ============") | |||
| dice_score = model.eval(valid_dataset, dataset_sink_mode=False) | |||
| print("============== Cross valid dice coeff is:", dice_score) | |||
| print("============== Cross valid dice coeff is:", dice_score) | |||
| ``` | |||
| ### Continue Training on the Pretrained Model | |||
| - Running on Ascend 310 | |||
| - running on Ascend | |||
| Export MindIR | |||
| ```shell | |||
| python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT] | |||
| ``` | |||
| The ckpt_file parameter is required, | |||
| `EXPORT_FORMAT` should be in ["AIR", "MINDIR"] | |||
| Before performing inference, the MINDIR file must be exported by export script on the 910 environment. | |||
| Current batch_size can only be set to 1. | |||
| ```shell | |||
| # Ascend310 inference | |||
| bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] | |||
| ``` | |||
| `DEVICE_ID` is optional, default value is 0. | |||
| Inference result is saved in current path, you can find result in acc.log file. | |||
| ```text | |||
| Cross valid dice coeff is: 0.9054352151297033 | |||
| ``` | |||
| ### Continue Training on the Pretrained Model | |||
| - running on Ascend | |||
| ```python | |||
| # Define model | |||
| net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes']) | |||
| # Continue training if set 'resume' to be True | |||
| @@ -276,33 +298,32 @@ If you need to use the trained model to perform inference on multiple hardware p | |||
| # Load dataset | |||
| train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute) | |||
| train_data_size = train_dataset.get_dataset_size() | |||
| optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=cfg['weight_decay'], | |||
| loss_scale=cfg['loss_scale']) | |||
| criterion = CrossEntropyWithLogits() | |||
| loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(cfg['FixedLossScaleManager'], False) | |||
| model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3") | |||
| # Set callbacks | |||
| # Set callbacks | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size, | |||
| keep_checkpoint_max=cfg['keep_checkpoint_max']) | |||
| ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam', | |||
| directory='./ckpt_{}/'.format(device_id), | |||
| config=ckpt_config) | |||
| print("============== Starting Training ==============") | |||
| model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb], | |||
| dataset_sink_mode=False) | |||
| print("============== End Training ==============") | |||
| ``` | |||
| # [Description of Random Situation](#contents) | |||
| In data_loader.py, we set the seed inside “_get_val_train_indices" function. We also use random seed in train.py. | |||
| In data_loader.py, we set the seed inside “_get_val_train_indices" function. We also use random seed in train.py. | |||
| # [ModelZoo Homepage](#contents) | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| @@ -254,6 +254,33 @@ step: 300, loss is 0.18949677, fps is 57.63118508760329 | |||
| print("============== Starting Evaluating ============") | |||
| dice_score = model.eval(valid_dataset, dataset_sink_mode=False) | |||
| print("============== Cross valid dice coeff is:", dice_score) | |||
| ``` | |||
| - Ascend 310环境运行 | |||
| 导出mindir模型 | |||
| ```shell | |||
| python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT] | |||
| ``` | |||
| 参数`ckpt_file` 是必需的,`EXPORT_FORMAT` 必须在 ["AIR", "MINDIR"]中进行选择。 | |||
| 在执行推理前,MINDIR文件必须在910上通过export.py文件导出。 | |||
| 目前仅可处理batch_Size为1。 | |||
| ```shell | |||
| # Ascend310 推理 | |||
| bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] | |||
| ``` | |||
| `DEVICE_ID` 可选,默认值为 0。 | |||
| 推理结果保存在当前路径,可在acc.log中看到最终精度结果。 | |||
| ```text | |||
| Cross valid dice coeff is: 0.9054352151297033 | |||
| ``` | |||
| ### 继续训练预训练模型 | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_INFERENCE_UTILS_H_ | |||
| #define MINDSPORE_INFERENCE_UTILS_H_ | |||
| #include <sys/stat.h> | |||
| #include <dirent.h> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "include/api/types.h" | |||
| std::vector<std::string> GetAllFiles(std::string_view dirName); | |||
| DIR *OpenDir(std::string_view dirName); | |||
| std::string RealPath(std::string_view path); | |||
| mindspore::MSTensor ReadFileToTensor(const std::string &file); | |||
| int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs); | |||
| #endif | |||
| @@ -0,0 +1,14 @@ | |||
| cmake_minimum_required(VERSION 3.14.1) | |||
| project(MindSporeCxxTestcase[CXX]) | |||
| add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined") | |||
| set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/) | |||
| option(MINDSPORE_PATH "mindspore install path" "") | |||
| include_directories(${MINDSPORE_PATH}) | |||
| include_directories(${MINDSPORE_PATH}/include) | |||
| include_directories(${PROJECT_SRC_ROOT}/../inc) | |||
| find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib) | |||
| file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*) | |||
| add_executable(main main.cc utils.cc) | |||
| target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags) | |||
| @@ -0,0 +1,18 @@ | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| cmake . -DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" | |||
| make | |||
| @@ -0,0 +1,123 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <sys/time.h> | |||
| #include <gflags/gflags.h> | |||
| #include <dirent.h> | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <iosfwd> | |||
| #include <vector> | |||
| #include <fstream> | |||
| #include "include/api/model.h" | |||
| #include "include/api/serialization.h" | |||
| #include "include/api/context.h" | |||
| #include "include/minddata/dataset/include/execute.h" | |||
| #include "include/minddata/dataset/include/vision.h" | |||
| #include "../inc/utils.h" | |||
| #include "include/api/types.h" | |||
| using mindspore::Context; | |||
| using mindspore::GlobalContext; | |||
| using mindspore::ModelContext; | |||
| using mindspore::Serialization; | |||
| using mindspore::Model; | |||
| using mindspore::Status; | |||
| using mindspore::dataset::Execute; | |||
| using mindspore::MSTensor; | |||
| using mindspore::ModelType; | |||
| using mindspore::GraphCell; | |||
| using mindspore::kSuccess; | |||
| DEFINE_string(mindir_path, "", "mindir path"); | |||
| DEFINE_string(dataset_path, ".", "dataset path"); | |||
| DEFINE_int32(device_id, 0, "device id"); | |||
| int main(int argc, char **argv) { | |||
| gflags::ParseCommandLineFlags(&argc, &argv, true); | |||
| if (RealPath(FLAGS_mindir_path).empty()) { | |||
| std::cout << "Invalid mindir" << std::endl; | |||
| return 1; | |||
| } | |||
| GlobalContext::SetGlobalDeviceTarget(mindspore::kDeviceTypeAscend310); | |||
| GlobalContext::SetGlobalDeviceID(FLAGS_device_id); | |||
| auto graph = Serialization::LoadModel(FLAGS_mindir_path, ModelType::kMindIR); | |||
| auto model_context = std::make_shared<Context>(); | |||
| Model model(GraphCell(graph), model_context); | |||
| Status ret = model.Build(); | |||
| if (ret != kSuccess) { | |||
| std::cout << "EEEEEEEERROR Build failed." << std::endl; | |||
| return 1; | |||
| } | |||
| std::vector<MSTensor> model_inputs = model.GetInputs(); | |||
| auto all_files = GetAllFiles(FLAGS_dataset_path); | |||
| if (all_files.empty()) { | |||
| std::cout << "ERROR: no input data." << std::endl; | |||
| return 1; | |||
| } | |||
| std::map<double, double> costTime_map; | |||
| size_t size = all_files.size(); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| struct timeval start = {0}; | |||
| struct timeval end = {0}; | |||
| double startTime_ms; | |||
| double endTime_ms; | |||
| std::vector<MSTensor> inputs; | |||
| std::vector<MSTensor> outputs; | |||
| std::cout << "Start predict input files:" << all_files[i] << std::endl; | |||
| auto img = ReadFileToTensor(all_files[i]); | |||
| inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(), | |||
| img.Data().get(), img.DataSize()); | |||
| gettimeofday(&start, NULL); | |||
| ret = model.Predict(inputs, &outputs); | |||
| gettimeofday(&end, NULL); | |||
| if (ret != kSuccess) { | |||
| std::cout << "Predict " << all_files[i] << " failed." << std::endl; | |||
| return 1; | |||
| } | |||
| startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000; | |||
| endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000; | |||
| costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms)); | |||
| WriteResult(all_files[i], outputs); | |||
| } | |||
| double average = 0.0; | |||
| int infer_cnt = 0; | |||
| char tmpCh[256] = {0}; | |||
| for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) { | |||
| double diff = 0.0; | |||
| diff = iter->second - iter->first; | |||
| average += diff; | |||
| infer_cnt++; | |||
| } | |||
| average = average/infer_cnt; | |||
| snprintf(tmpCh, sizeof(tmpCh), "NN inference cost average time: %4.3f ms of infer_count %d \n", average, infer_cnt); | |||
| std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << infer_cnt << std::endl; | |||
| std::string file_name = "./time_Result" + std::string("/test_perform_static.txt"); | |||
| std::ofstream file_stream(file_name.c_str(), std::ios::trunc); | |||
| file_stream << tmpCh; | |||
| file_stream.close(); | |||
| costTime_map.clear(); | |||
| return 0; | |||
| } | |||
| @@ -0,0 +1,136 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "../inc/utils.h" | |||
| #include <fstream> | |||
| #include <algorithm> | |||
| #include <iostream> | |||
| using mindspore::MSTensor; | |||
| using mindspore::DataType; | |||
| std::vector<std::string> GetAllFiles(std::string_view dirName) { | |||
| struct dirent *filename; | |||
| DIR *dir = OpenDir(dirName); | |||
| if (dir == nullptr) { | |||
| return {}; | |||
| } | |||
| std::vector<std::string> res; | |||
| while ((filename = readdir(dir)) != nullptr) { | |||
| std::string dName = std::string(filename->d_name); | |||
| if (dName == "." || | |||
| dName == ".." || | |||
| filename->d_type != DT_REG) | |||
| continue; | |||
| res.emplace_back(std::string(dirName) + "/" + filename->d_name); | |||
| } | |||
| std::sort(res.begin(), res.end()); | |||
| for (auto &f : res) { | |||
| std::cout << "image file: " << f << std::endl; | |||
| } | |||
| return res; | |||
| } | |||
| int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) { | |||
| std::string homePath = "./result_Files"; | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| size_t outputSize; | |||
| std::shared_ptr<const void> netOutput; | |||
| netOutput = outputs[i].Data(); | |||
| outputSize = outputs[i].DataSize(); | |||
| int pos = imageFile.rfind('/'); | |||
| std::string fileName(imageFile, pos + 1); | |||
| fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin"); | |||
| std::string outFileName = homePath + "/" + fileName; | |||
| FILE * outputFile = fopen(outFileName.c_str(), "wb"); | |||
| fwrite(netOutput.get(), outputSize, sizeof(char), outputFile); | |||
| fclose(outputFile); | |||
| outputFile = nullptr; | |||
| } | |||
| return 0; | |||
| } | |||
| MSTensor ReadFileToTensor(const std::string &file) { | |||
| if (file.empty()) { | |||
| std::cout << "Pointer file is nullptr" << std::endl; | |||
| return MSTensor(); | |||
| } | |||
| std::ifstream ifs(file); | |||
| if (!ifs.good()) { | |||
| std::cout << "File: " << file << " is not exist" << std::endl; | |||
| return MSTensor(); | |||
| } | |||
| if (!ifs.is_open()) { | |||
| std::cout << "File: " << file << "open failed" << std::endl; | |||
| return MSTensor(); | |||
| } | |||
| ifs.seekg(0, std::ios::end); | |||
| size_t size = ifs.tellg(); | |||
| MSTensor buffer(file, DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size); | |||
| ifs.seekg(0, std::ios::beg); | |||
| ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size); | |||
| ifs.close(); | |||
| return buffer; | |||
| } | |||
| DIR *OpenDir(std::string_view dirName) { | |||
| if (dirName.empty()) { | |||
| std::cout << " dirName is null ! " << std::endl; | |||
| return nullptr; | |||
| } | |||
| std::string realPath = RealPath(dirName); | |||
| struct stat s; | |||
| lstat(realPath.c_str(), &s); | |||
| if (!S_ISDIR(s.st_mode)) { | |||
| std::cout << "dirName is not a valid directory !" << std::endl; | |||
| return nullptr; | |||
| } | |||
| DIR *dir; | |||
| dir = opendir(realPath.c_str()); | |||
| if (dir == nullptr) { | |||
| std::cout << "Can not open dir " << dirName << std::endl; | |||
| return nullptr; | |||
| } | |||
| std::cout << "Successfully opened the dir " << dirName << std::endl; | |||
| return dir; | |||
| } | |||
| std::string RealPath(std::string_view path) { | |||
| char real_path_mem[PATH_MAX] = {0}; | |||
| char *real_path_ret = nullptr; | |||
| real_path_ret = realpath(path.data(), real_path_mem); | |||
| if (real_path_ret == nullptr) { | |||
| std::cout << "File: " << path << " is not exist."; | |||
| return ""; | |||
| } | |||
| std::string real_path(real_path_mem); | |||
| std::cout << path << " realpath is: " << real_path << std::endl; | |||
| return real_path; | |||
| } | |||
| @@ -0,0 +1,97 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """unet 310 infer.""" | |||
| import os | |||
| import argparse | |||
| import numpy as np | |||
| from src.data_loader import create_dataset | |||
| from src.config import cfg_unet | |||
| from scipy.special import softmax | |||
| class dice_coeff(): | |||
| def __init__(self): | |||
| self.clear() | |||
| def clear(self): | |||
| self._dice_coeff_sum = 0 | |||
| self._samples_num = 0 | |||
| def update(self, *inputs): | |||
| if len(inputs) != 2: | |||
| raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | |||
| y_pred = inputs[0] | |||
| y = np.array(inputs[1]) | |||
| self._samples_num += y.shape[0] | |||
| y_pred = y_pred.transpose(0, 2, 3, 1) | |||
| y = y.transpose(0, 2, 3, 1) | |||
| y_pred = softmax(y_pred, axis=3) | |||
| inter = np.dot(y_pred.flatten(), y.flatten()) | |||
| union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) | |||
| single_dice_coeff = 2*float(inter)/float(union+1e-6) | |||
| print("single dice coeff is:", single_dice_coeff) | |||
| self._dice_coeff_sum += single_dice_coeff | |||
| def eval(self): | |||
| if self._samples_num == 0: | |||
| raise RuntimeError('Total samples num must not be 0.') | |||
| return self._dice_coeff_sum / float(self._samples_num) | |||
| def test_net(data_dir, | |||
| cross_valid_ind=1, | |||
| cfg=None): | |||
| _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False) | |||
| labels_list = [] | |||
| for data in valid_dataset: | |||
| labels_list.append(data[1].asnumpy()) | |||
| return labels_list | |||
| def get_args(): | |||
| parser = argparse.ArgumentParser(description='Test the UNet on images and target masks', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/', | |||
| help='data directory') | |||
| parser.add_argument('-p', '--rst_path', dest='rst_path', type=str, default='./result_Files/', | |||
| help='infer result path') | |||
| return parser.parse_args() | |||
| if __name__ == '__main__': | |||
| args = get_args() | |||
| label_list = test_net(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet) | |||
| rst_path = args.rst_path | |||
| metrics = dice_coeff() | |||
| for j in range(len(os.listdir(rst_path))): | |||
| file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" | |||
| output = np.fromfile(file_name, np.float32).reshape(1, 2, 388, 388) | |||
| label = label_list[j] | |||
| metrics.update(output, label) | |||
| print("Cross valid dice coeff is: ", metrics.eval()) | |||
| @@ -0,0 +1,45 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """unet 310 infer preprocess dataset""" | |||
| import argparse | |||
| from src.data_loader import create_dataset | |||
| from src.config import cfg_unet | |||
| def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None): | |||
| _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False) | |||
| for i, data in enumerate(valid_dataset): | |||
| file_name = "ISBI_test_bs_1_" + str(i) + ".bin" | |||
| file_path = result_path + file_name | |||
| data[0].asnumpy().tofile(file_path) | |||
| def get_args(): | |||
| parser = argparse.ArgumentParser(description='Preprocess the UNet dataset ', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/', | |||
| help='data directory') | |||
| parser.add_argument('-p', '--result_path', dest='result_path', type=str, default='./preprocess_Result/', | |||
| help='result path') | |||
| return parser.parse_args() | |||
| if __name__ == '__main__': | |||
| args = get_args() | |||
| preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet, result_path= | |||
| args.result_path) | |||
| @@ -0,0 +1,115 @@ | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [[ $# -lt 2 || $# -gt 3 ]]; then | |||
| echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] | |||
| DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" | |||
| exit 1 | |||
| fi | |||
| get_real_path(){ | |||
| if [ "${1:0:1}" == "/" ]; then | |||
| echo "$1" | |||
| else | |||
| echo "$(realpath -m $PWD/$1)" | |||
| fi | |||
| } | |||
| model=$(get_real_path $1) | |||
| data_path=$(get_real_path $2) | |||
| if [ $# == 3 ]; then | |||
| device_id=$3 | |||
| if [ -z $device_id ]; then | |||
| device_id=0 | |||
| else | |||
| device_id=$device_id | |||
| fi | |||
| fi | |||
| echo "mindir name: "$model | |||
| echo "dataset path: "$data_path | |||
| echo "device id: "$device_id | |||
| export ASCEND_HOME=/usr/local/Ascend/ | |||
| if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then | |||
| export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH | |||
| export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH | |||
| export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe | |||
| export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH | |||
| export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp | |||
| else | |||
| export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH | |||
| export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH | |||
| export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages/te.egg:$ASCEND_HOME/atc/python/site-packages/topi.egg:$ASCEND_HOME/atc/python/site-packages/auto_tune.egg::$ASCEND_HOME/atc/python/site-packages/schedule_search.egg:$PYTHONPATH | |||
| export ASCEND_OPP_PATH=$ASCEND_HOME/opp | |||
| fi | |||
| function preprocess_data() | |||
| { | |||
| if [ -d preprocess_Result ]; then | |||
| rm -rf ./preprocess_Result | |||
| fi | |||
| mkdir preprocess_Result | |||
| python3.7 ../preprocess.py --data_url=$data_path --result_path=./preprocess_Result/ | |||
| } | |||
| function compile_app() | |||
| { | |||
| cd ../ascend310_infer/src | |||
| if [ -f "Makefile" ]; then | |||
| make clean | |||
| fi | |||
| sh build.sh &> build.log | |||
| } | |||
| function infer() | |||
| { | |||
| cd - | |||
| if [ -d result_Files ]; then | |||
| rm -rf ./result_Files | |||
| fi | |||
| if [ -d time_Result ]; then | |||
| rm -rf ./time_Result | |||
| fi | |||
| mkdir result_Files | |||
| mkdir time_Result | |||
| ../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id &> infer.log | |||
| } | |||
| function cal_acc() | |||
| { | |||
| python3.7 ../postprocess.py --data_url=$data_path --rst_path=./result_Files/ &> acc.log & | |||
| } | |||
| preprocess_data | |||
| if [ $? -ne 0 ]; then | |||
| echo "preprocess dataset failed" | |||
| exit 1 | |||
| fi | |||
| compile_app | |||
| if [ $? -ne 0 ]; then | |||
| echo "compile app code failed" | |||
| exit 1 | |||
| fi | |||
| infer | |||
| if [ $? -ne 0 ]; then | |||
| echo "execute inference failed" | |||
| exit 1 | |||
| fi | |||
| cal_acc | |||
| if [ $? -ne 0 ]; then | |||
| echo "calculate accuracy failed" | |||
| exit 1 | |||
| fi | |||