@@ -42,11 +42,17 @@ using mindspore::MSTensor;
 using mindspore::ModelType;
 using mindspore::GraphCell;
 using mindspore::kSuccess;
+using mindspore::dataset::vision::Decode;
+using mindspore::dataset::vision::SwapRedBlue;
+using mindspore::dataset::vision::Normalize;
+using mindspore::dataset::vision::Resize;
+using mindspore::dataset::vision::HWC2CHW;

 DEFINE_string(mindir_path, "", "mindir path");
 DEFINE_string(dataset_path, ".", "dataset path");
 DEFINE_int32(device_id, 0, "device id");
+DEFINE_string(need_preprocess, "n", "need preprocess or not");

 int main(int argc, char **argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -78,6 +84,14 @@ int main(int argc, char **argv) {
   std::map<double, double> costTime_map;
   size_t size = all_files.size();
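+  // Eval-time transform pipeline: decode the PNG, swap OpenCV-style BGR to RGB,
+  // resize to the 96x96 network input, map pixels to [-1, 1] via (x - 127.5) / 127.5,
+  // and switch HWC to the CHW layout the model expects.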
+  auto decode(new Decode());
+  auto swapredblue(new SwapRedBlue());
+  auto resize(new Resize({96, 96}));
+  auto normalize(new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5}));
+  auto hwc2chw(new HWC2CHW());
+  Execute preprocess({decode, swapredblue, resize, normalize, hwc2chw});
   for (size_t i = 0; i < size; ++i) {
     struct timeval start = {0};
     struct timeval end = {0};
@@ -86,7 +100,17 @@ int main(int argc, char **argv) {
     std::vector<MSTensor> inputs;
     std::vector<MSTensor> outputs;
     std::cout << "Start predict input files:" << all_files[i] << std::endl;
-    auto img = ReadFileToTensor(all_files[i]);
+    auto img = MSTensor();
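+    // With --need_preprocess=y the raw image file is decoded and transformed by the
+    // Execute pipeline above; with 'n' the file is assumed to already be a
+    // preprocessed tensor dump and is fed to the model unchanged.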
+    if (FLAGS_need_preprocess == "y") {
+      ret = preprocess(ReadFileToTensor(all_files[i]), &img);
+      if (ret != kSuccess) {
+        std::cout << "preprocess " << all_files[i] << " failed." << std::endl;
+        return 1;
+      }
+    } else {
+      img = ReadFileToTensor(all_files[i]);
+    }
     inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
                         img.Data().get(), img.DataSize());
@@ -21,6 +21,7 @@ from mindspore import Tensor, export, load_checkpoint, load_param_into_net, cont
 from src.unet_medical.unet_model import UNetMedical
 from src.unet_nested import NestedUNet, UNet
 from src.config import cfg_unet as cfg
+from src.utils import UnetEval

 parser = argparse.ArgumentParser(description='unet export')
 parser.add_argument("--device_id", type=int, default=0, help="Device id")
| @@ -52,5 +53,6 @@ if __name__ == "__main__": | |||||
| param_dict = load_checkpoint(args.ckpt_file) | param_dict = load_checkpoint(args.ckpt_file) | ||||
| # load the parameter into net | # load the parameter into net | ||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
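+    # Wrap the network so the exported MindIR returns both the softmax probabilities
+    # and the argmax class map, the two outputs postprocess.py reads back.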
+    net = UnetEval(net)
     input_data = Tensor(np.ones([args.batch_size, cfg["num_channels"], args.height, args.width]).astype(np.float32))
     export(net, input_data, file_name=args.file_name, file_format=args.file_format)
@@ -15,53 +15,77 @@
 """unet 310 infer."""
 import os
 import argparse
+import cv2
 import numpy as np
-from src.data_loader import create_dataset
+from src.data_loader import create_dataset, create_cell_nuclei_dataset
 from src.config import cfg_unet
-from scipy.special import softmax

 class dice_coeff():
     def __init__(self):
         self.clear()

     def clear(self):
         self._dice_coeff_sum = 0
+        self._iou_sum = 0
         self._samples_num = 0

     def update(self, *inputs):
         if len(inputs) != 2:
-            raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
-        y_pred = inputs[0]
+            raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
         y = np.array(inputs[1])
         self._samples_num += y.shape[0]
-        y_pred = y_pred.transpose(0, 2, 3, 1)
         y = y.transpose(0, 2, 3, 1)
-        y_pred = softmax(y_pred, axis=3)
+        b, h, w, c = y.shape
+        if b != 1:
+            raise ValueError('Batch size should be 1 when in evaluation.')
+        y = y.reshape((h, w, c))
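+        # The exported model emits a (softmax, argmax) pair; eval_activate selects
+        # which one is scored, and eval_resize rescales predictions back to the
+        # label's spatial size before the overlap is computed.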
| if cfg_unet["eval_activate"].lower() == "softmax": | |||||
| y_softmax = np.squeeze(inputs[0][0], axis=0) | |||||
| if cfg_unet["eval_resize"]: | |||||
| y_pred = [] | |||||
| for m in range(cfg_unet["num_classes"]): | |||||
| y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255) | |||||
| y_pred = np.stack(y_pred, axis=-1) | |||||
| else: | |||||
| y_pred = y_softmax | |||||
| elif cfg_unet["eval_activate"].lower() == "argmax": | |||||
| y_argmax = np.squeeze(inputs[0][1], axis=0) | |||||
| y_pred = [] | |||||
| for n in range(cfg_unet["num_classes"]): | |||||
| if cfg_unet["eval_resize"]: | |||||
| y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST)) | |||||
| else: | |||||
| y_pred.append(np.float32(y_argmax == n)) | |||||
| y_pred = np.stack(y_pred, axis=-1) | |||||
| else: | |||||
| raise ValueError('config eval_activate should be softmax or argmax.') | |||||
| y_pred = y_pred.astype(np.float32) | |||||
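+        # Soft Dice over the flattened maps: 2*<p, y> / (<p, p> + <y, y>);
+        # IoU follows directly from dice as dice / (2 - dice).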
         inter = np.dot(y_pred.flatten(), y.flatten())
         union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
         single_dice_coeff = 2*float(inter)/float(union+1e-6)
-        print("single dice coeff is:", single_dice_coeff)
+        single_iou = single_dice_coeff / (2 - single_dice_coeff)
+        print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
         self._dice_coeff_sum += single_dice_coeff
+        self._iou_sum += single_iou

     def eval(self):
         if self._samples_num == 0:
             raise RuntimeError('Total samples num must not be 0.')
-        return self._dice_coeff_sum / float(self._samples_num)
+        return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))

 def test_net(data_dir,
              cross_valid_ind=1,
              cfg=None):
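+    # Cell_nuclei evaluates with its own dataset pipeline; every other config keeps
+    # the original ISBI cross-validation split.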
-    _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
-                                      img_size=cfg['img_size'])
+    if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
+        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
+                                                   eval_resize=cfg["eval_resize"], split=0.8)
+    else:
+        _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'],
+                                          img_size=cfg['img_size'])
     labels_list = []

     for data in valid_dataset:
@@ -89,10 +113,25 @@ if __name__ == '__main__':
     rst_path = args.rst_path
     metrics = dice_coeff()

-    for j in range(len(os.listdir(rst_path))):
-        file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin"
-        output = np.fromfile(file_name, np.float32).reshape(1, 2, 576, 576)
-        label = label_list[j]
-        metrics.update(output, label)
-    print("Cross valid dice coeff is: ", metrics.eval())
+    if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
+        for i, bin_name in enumerate(os.listdir('./preprocess_Result/')):
+            bin_name_softmax = bin_name.replace(".png", "") + "_0.bin"
+            bin_name_argmax = bin_name.replace(".png", "") + "_1.bin"
+            file_name_sof = rst_path + bin_name_softmax
+            file_name_arg = rst_path + bin_name_argmax
+            softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2)
+            argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96)
+            label = label_list[i]
+            metrics.update((softmax_out, argmax_out), label)
+    else:
+        for j in range(len(os.listdir('./preprocess_Result/'))):
+            file_name_sof = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin"
+            file_name_arg = rst_path + "ISBI_test_bs_1_" + str(j) + "_1" + ".bin"
+            softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 576, 576, 2)
+            argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 576, 576)
+            label = label_list[j]
+            metrics.update((softmax_out, argmax_out), label)
+    eval_score = metrics.eval()
+    print("============== Cross valid dice coeff is:", eval_score[0])
+    print("============== Cross valid IOU is:", eval_score[1])
@@ -14,6 +14,10 @@
 # ============================================================================
 """unet 310 infer preprocess dataset"""
 import argparse
+import os
+import numpy as np
+import cv2
 from src.data_loader import create_dataset
 from src.config import cfg_unet

@@ -29,6 +33,56 @@ def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None):
         data[0].asnumpy().tofile(file_path)

+
+class CellNucleiDataset:
+    """
+    Cell nuclei dataset preprocess class.
+    """
+    def __init__(self, data_dir, repeat, result_path, is_train=False, split=0.8):
+        self.data_dir = data_dir
+        self.img_ids = sorted(next(os.walk(self.data_dir))[1])
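+        # Sorting the directory names keeps the 80/20 split deterministic, so the
+        # val_ids chosen here line up with the validation split used at evaluation.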
+        self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
+        np.random.shuffle(self.train_ids)
+        self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
+        self.is_train = is_train
+        self.result_path = result_path
+        self._preprocess_dataset()
+
+    def _preprocess_dataset(self):
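+        # Each sample directory holds one image plus one mask per nucleus instance;
+        # the instance masks are merged with a per-pixel max, and the validation
+        # image is written to result_path as the PNG input for 310 inference.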
+        for img_id in self.val_ids:
+            path = os.path.join(self.data_dir, img_id)
+            img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
+            if len(img.shape) == 2:
+                img = np.expand_dims(img, axis=-1)
+                img = np.concatenate([img, img, img], axis=-1)
+            mask = []
+            for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
+                mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
+                mask.append(mask_)
+            mask = np.max(mask, axis=0)
+            cv2.imwrite(os.path.join(self.result_path, img_id + ".png"), img)
+
+    def _read_img_mask(self, img_id):
+        path = os.path.join(self.data_dir, img_id)
+        img = cv2.imread(os.path.join(path, "image.png"))
+        mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
+        return img, mask
+
+    def __getitem__(self, index):
+        if self.is_train:
+            return self._read_img_mask(self.train_ids[index])
+        return self._read_img_mask(self.val_ids[index])
+
+    @property
+    def column_names(self):
+        column_names = ['image', 'mask']
+        return column_names
+
+    def __len__(self):
+        if self.is_train:
+            return len(self.train_ids)
+        return len(self.val_ids)
+
 def get_args():
     parser = argparse.ArgumentParser(description='Preprocess the UNet dataset ',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)

@@ -42,5 +96,8 @@ def get_args():
 if __name__ == '__main__':
     args = get_args()
-    preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet, result_path=
-                       args.result_path)
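+    # For Cell_nuclei the validation images are dumped as PNGs and decoded later by
+    # the C++ inference binary (need_preprocess=y); other datasets keep dumping
+    # preprocessed .bin tensors.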
+    if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei":
+        cell_dataset = CellNucleiDataset(args.data_url, 1, args.result_path, False, 0.8)
+    else:
+        preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet,
+                           result_path=args.result_path)
@@ -14,9 +14,10 @@
 # limitations under the License.
 # ============================================================================

-if [[ $# -lt 2 || $# -gt 3 ]]; then
-    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
-    DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
+if [[ $# -lt 3 || $# -gt 4 ]]; then
+    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [NEED_PREPROCESS]
+    DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero.
+    NEED_PREPROCESS means whether preprocessing is needed; its value is 'y' or 'n'."
     exit 1
 fi
@@ -29,7 +30,7 @@ get_real_path(){
 }

 model=$(get_real_path $1)
 data_path=$(get_real_path $2)
-if [ $# == 3 ]; then
+if [ $# == 4 ]; then
     device_id=$3
     if [ -z $device_id ]; then
         device_id=0
@@ -37,10 +38,12 @@ if [ $# == 3 ]; then
        device_id=$device_id
    fi
 fi
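+# 'y' tells the inference binary to decode and transform the PNG inputs itself;
+# 'n' expects already preprocessed .bin tensors in preprocess_Result.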
+need_preprocess=$4

 echo "mindir name: "$model
 echo "dataset path: "$data_path
 echo "device id: "$device_id
+echo "need preprocess or not: "$need_preprocess

 export ASCEND_HOME=/usr/local/Ascend/
 if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
@@ -85,7 +88,7 @@ function infer()
     fi
     mkdir result_Files
     mkdir time_Result
-    ../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id &> infer.log
+    ../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id --need_preprocess=$need_preprocess &> infer.log
 }

 function cal_acc()