| @@ -42,11 +42,17 @@ using mindspore::MSTensor; | |||
| using mindspore::ModelType; | |||
| using mindspore::GraphCell; | |||
| using mindspore::kSuccess; | |||
| using mindspore::dataset::vision::Decode; | |||
| using mindspore::dataset::vision::SwapRedBlue; | |||
| using mindspore::dataset::vision::Normalize; | |||
| using mindspore::dataset::vision::Resize; | |||
| using mindspore::dataset::vision::HWC2CHW; | |||
| DEFINE_string(mindir_path, "", "mindir path"); | |||
| DEFINE_string(dataset_path, ".", "dataset path"); | |||
| DEFINE_int32(device_id, 0, "device id"); | |||
| DEFINE_string(need_preprocess, "n", "need preprocess or not"); | |||
| int main(int argc, char **argv) { | |||
| gflags::ParseCommandLineFlags(&argc, &argv, true); | |||
| @@ -78,6 +84,14 @@ int main(int argc, char **argv) { | |||
| std::map<double, double> costTime_map; | |||
| size_t size = all_files.size(); | |||
| auto decode(new Decode()); | |||
| auto swapredblue(new SwapRedBlue()); | |||
| auto resize(new Resize({96, 96})); | |||
| auto normalize(new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5})); | |||
| auto hwc2chw(new HWC2CHW()); | |||
| Execute preprocess({decode, swapredblue, resize, normalize, hwc2chw}); | |||
| for (size_t i = 0; i < size; ++i) { | |||
| struct timeval start = {0}; | |||
| struct timeval end = {0}; | |||
| @@ -86,7 +100,17 @@ int main(int argc, char **argv) { | |||
| std::vector<MSTensor> inputs; | |||
| std::vector<MSTensor> outputs; | |||
| std::cout << "Start predict input files:" << all_files[i] << std::endl; | |||
| auto img = ReadFileToTensor(all_files[i]); | |||
| auto img = MSTensor(); | |||
| if (FLAGS_need_preprocess == "y") { | |||
| ret = preprocess(ReadFileToTensor(all_files[i]), &img); | |||
| if (ret != kSuccess) { | |||
| std::cout << "preprocess " << all_files[i] << " failed." << std::endl; | |||
| return 1; | |||
| } | |||
| } else { | |||
| img = ReadFileToTensor(all_files[i]); | |||
| } | |||
| inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(), | |||
| img.Data().get(), img.DataSize()); | |||
| @@ -21,6 +21,7 @@ from mindspore import Tensor, export, load_checkpoint, load_param_into_net, cont | |||
| from src.unet_medical.unet_model import UNetMedical | |||
| from src.unet_nested import NestedUNet, UNet | |||
| from src.config import cfg_unet as cfg | |||
| from src.utils import UnetEval | |||
| parser = argparse.ArgumentParser(description='unet export') | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id") | |||
| @@ -52,5 +53,6 @@ if __name__ == "__main__": | |||
| param_dict = load_checkpoint(args.ckpt_file) | |||
| # load the parameter into net | |||
| load_param_into_net(net, param_dict) | |||
| net = UnetEval(net) | |||
| input_data = Tensor(np.ones([args.batch_size, cfg["num_channels"], args.height, args.width]).astype(np.float32)) | |||
| export(net, input_data, file_name=args.file_name, file_format=args.file_format) | |||
| @@ -15,53 +15,77 @@ | |||
| """unet 310 infer.""" | |||
| import os | |||
| import argparse | |||
| import cv2 | |||
| import numpy as np | |||
| from src.data_loader import create_dataset | |||
| from src.data_loader import create_dataset, create_cell_nuclei_dataset | |||
| from src.config import cfg_unet | |||
| from scipy.special import softmax | |||
| class dice_coeff(): | |||
| def __init__(self): | |||
| self.clear() | |||
| def clear(self): | |||
| self._dice_coeff_sum = 0 | |||
| self._iou_sum = 0 | |||
| self._samples_num = 0 | |||
| def update(self, *inputs): | |||
| if len(inputs) != 2: | |||
| raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | |||
| y_pred = inputs[0] | |||
| raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs))) | |||
| y = np.array(inputs[1]) | |||
| self._samples_num += y.shape[0] | |||
| y_pred = y_pred.transpose(0, 2, 3, 1) | |||
| y = y.transpose(0, 2, 3, 1) | |||
| y_pred = softmax(y_pred, axis=3) | |||
| b, h, w, c = y.shape | |||
| if b != 1: | |||
| raise ValueError('Batch size should be 1 when in evaluation.') | |||
| y = y.reshape((h, w, c)) | |||
| if cfg_unet["eval_activate"].lower() == "softmax": | |||
| y_softmax = np.squeeze(inputs[0][0], axis=0) | |||
| if cfg_unet["eval_resize"]: | |||
| y_pred = [] | |||
| for m in range(cfg_unet["num_classes"]): | |||
| y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255) | |||
| y_pred = np.stack(y_pred, axis=-1) | |||
| else: | |||
| y_pred = y_softmax | |||
| elif cfg_unet["eval_activate"].lower() == "argmax": | |||
| y_argmax = np.squeeze(inputs[0][1], axis=0) | |||
| y_pred = [] | |||
| for n in range(cfg_unet["num_classes"]): | |||
| if cfg_unet["eval_resize"]: | |||
| y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST)) | |||
| else: | |||
| y_pred.append(np.float32(y_argmax == n)) | |||
| y_pred = np.stack(y_pred, axis=-1) | |||
| else: | |||
| raise ValueError('config eval_activate should be softmax or argmax.') | |||
| y_pred = y_pred.astype(np.float32) | |||
| inter = np.dot(y_pred.flatten(), y.flatten()) | |||
| union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) | |||
| single_dice_coeff = 2*float(inter)/float(union+1e-6) | |||
| print("single dice coeff is:", single_dice_coeff) | |||
| single_iou = single_dice_coeff / (2 - single_dice_coeff) | |||
| print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou)) | |||
| self._dice_coeff_sum += single_dice_coeff | |||
| self._iou_sum += single_iou | |||
| def eval(self): | |||
| if self._samples_num == 0: | |||
| raise RuntimeError('Total samples num must not be 0.') | |||
| return self._dice_coeff_sum / float(self._samples_num) | |||
| return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num)) | |||
| def test_net(data_dir, | |||
| cross_valid_ind=1, | |||
| cfg=None): | |||
| _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], | |||
| img_size=cfg['img_size']) | |||
| if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": | |||
| valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, | |||
| eval_resize=cfg["eval_resize"], split=0.8) | |||
| else: | |||
| _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], | |||
| img_size=cfg['img_size']) | |||
| labels_list = [] | |||
| for data in valid_dataset: | |||
| @@ -89,10 +113,25 @@ if __name__ == '__main__': | |||
| rst_path = args.rst_path | |||
| metrics = dice_coeff() | |||
| for j in range(len(os.listdir(rst_path))): | |||
| file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" | |||
| output = np.fromfile(file_name, np.float32).reshape(1, 2, 576, 576) | |||
| label = label_list[j] | |||
| metrics.update(output, label) | |||
| print("Cross valid dice coeff is: ", metrics.eval()) | |||
| if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei": | |||
| for i, bin_name in enumerate(os.listdir('./preprocess_Result/')): | |||
| bin_name_softmax = bin_name.replace(".png", "") + "_0.bin" | |||
| bin_name_argmax = bin_name.replace(".png", "") + "_1.bin" | |||
| file_name_sof = rst_path + bin_name_softmax | |||
| file_name_arg = rst_path + bin_name_argmax | |||
| softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2) | |||
| argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96) | |||
| label = label_list[i] | |||
| metrics.update((softmax_out, argmax_out), label) | |||
| else: | |||
| for j in range(len(os.listdir('./preprocess_Result/'))): | |||
| file_name_sof = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" | |||
| file_name_arg = rst_path + "ISBI_test_bs_1_" + str(j) + "_1" + ".bin" | |||
| softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 576, 576, 2) | |||
| argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 576, 576) | |||
| label = label_list[j] | |||
| metrics.update((softmax_out, argmax_out), label) | |||
| eval_score = metrics.eval() | |||
| print("============== Cross valid dice coeff is:", eval_score[0]) | |||
| print("============== Cross valid IOU is:", eval_score[1]) | |||
| @@ -14,6 +14,10 @@ | |||
| # ============================================================================ | |||
| """unet 310 infer preprocess dataset""" | |||
| import argparse | |||
| import os | |||
| import numpy as np | |||
| import cv2 | |||
| from src.data_loader import create_dataset | |||
| from src.config import cfg_unet | |||
| @@ -29,6 +33,56 @@ def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None): | |||
| data[0].asnumpy().tofile(file_path) | |||
| class CellNucleiDataset: | |||
| """ | |||
| Cell nuclei dataset preprocess class. | |||
| """ | |||
| def __init__(self, data_dir, repeat, result_path, is_train=False, split=0.8): | |||
| self.data_dir = data_dir | |||
| self.img_ids = sorted(next(os.walk(self.data_dir))[1]) | |||
| self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat | |||
| np.random.shuffle(self.train_ids) | |||
| self.val_ids = self.img_ids[int(len(self.img_ids) * split):] | |||
| self.is_train = is_train | |||
| self.result_path = result_path | |||
| self._preprocess_dataset() | |||
| def _preprocess_dataset(self): | |||
| for img_id in self.val_ids: | |||
| path = os.path.join(self.data_dir, img_id) | |||
| img = cv2.imread(os.path.join(path, "images", img_id + ".png")) | |||
| if len(img.shape) == 2: | |||
| img = np.expand_dims(img, axis=-1) | |||
| img = np.concatenate([img, img, img], axis=-1) | |||
| mask = [] | |||
| for mask_file in next(os.walk(os.path.join(path, "masks")))[2]: | |||
| mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE) | |||
| mask.append(mask_) | |||
| mask = np.max(mask, axis=0) | |||
| cv2.imwrite(os.path.join(self.result_path, img_id + ".png"), img) | |||
| def _read_img_mask(self, img_id): | |||
| path = os.path.join(self.data_dir, img_id) | |||
| img = cv2.imread(os.path.join(path, "image.png")) | |||
| mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE) | |||
| return img, mask | |||
| def __getitem__(self, index): | |||
| if self.is_train: | |||
| return self._read_img_mask(self.train_ids[index]) | |||
| return self._read_img_mask(self.val_ids[index]) | |||
| @property | |||
| def column_names(self): | |||
| column_names = ['image', 'mask'] | |||
| return column_names | |||
| def __len__(self): | |||
| if self.is_train: | |||
| return len(self.train_ids) | |||
| return len(self.val_ids) | |||
| def get_args(): | |||
| parser = argparse.ArgumentParser(description='Preprocess the UNet dataset ', | |||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
| @@ -42,5 +96,8 @@ def get_args(): | |||
| if __name__ == '__main__': | |||
| args = get_args() | |||
| preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet, result_path= | |||
| args.result_path) | |||
| if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei": | |||
| cell_dataset = CellNucleiDataset(args.data_url, 1, args.result_path, False, 0.8) | |||
| else: | |||
| preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet, | |||
| result_path=args.result_path) | |||
| @@ -14,9 +14,10 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [[ $# -lt 2 || $# -gt 3 ]]; then | |||
| echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] | |||
| DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" | |||
| if [[ $# -lt 3 || $# -gt 4 ]]; then | |||
| echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [NEED_PREPROCESS] | |||
| DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero. | |||
| NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'." | |||
| exit 1 | |||
| fi | |||
| @@ -29,7 +30,7 @@ get_real_path(){ | |||
| } | |||
| model=$(get_real_path $1) | |||
| data_path=$(get_real_path $2) | |||
| if [ $# == 3 ]; then | |||
| if [ $# == 4 ]; then | |||
| device_id=$3 | |||
| if [ -z $device_id ]; then | |||
| device_id=0 | |||
| @@ -37,10 +38,12 @@ if [ $# == 3 ]; then | |||
| device_id=$device_id | |||
| fi | |||
| fi | |||
| need_preprocess=$4 | |||
| echo "mindir name: "$model | |||
| echo "dataset path: "$data_path | |||
| echo "device id: "$device_id | |||
| echo "need preprocess or not: "$need_preprocess | |||
| export ASCEND_HOME=/usr/local/Ascend/ | |||
| if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then | |||
| @@ -85,7 +88,7 @@ function infer() | |||
| fi | |||
| mkdir result_Files | |||
| mkdir time_Result | |||
| ../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id &> infer.log | |||
| ../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id --need_preprocess=$need_preprocess &> infer.log | |||
| } | |||
| function cal_acc() | |||