From: @qujianwei Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34pull/14736/MERGE
| @@ -131,6 +131,7 @@ crnn | |||
| │ ├── crnn.py # crnn network definition | |||
| │ ├── crnn_for_train.py # crnn network with grad, loss and gradient clip | |||
| │ ├── dataset.py # Data preprocessing for training and evaluation | |||
| │ ├── eval_callback.py # Evaluation callback while training | |||
| │ ├── ic03_dataset.py # Data preprocessing for IC03 | |||
| │ ├── ic13_dataset.py # Data preprocessing for IC13 | |||
| │ ├── iiit5k_dataset.py # Data preprocessing for IIIT5K | |||
| @@ -225,6 +226,10 @@ Check the `eval/log.txt` and you will get outputs as following: | |||
| result: {'CRNNAccuracy': (0.806)} | |||
| ``` | |||
| ### Evaluation while training | |||
| To run evaluation while training, add `run_eval` to the start shell script and set it to True. You also need to add `eval_dataset` to select which dataset to evaluate on, and `eval_dataset_path` to point to that dataset. When `run_eval` is True, you can additionally set the optional arguments `save_best_ckpt`, `eval_start_epoch` and `eval_interval`. | |||
| ## [Inference Process](#contents) | |||
| ### [Export MindIR](#contents) | |||
| @@ -0,0 +1,91 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation callback when training""" | |||
| import os | |||
| import stat | |||
| from mindspore import save_checkpoint | |||
| from mindspore import log as logger | |||
| from mindspore.train.callback import Callback | |||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    Args:
        eval_function (function): evaluation function. It is called as
            `eval_function(eval_param_dict)` and its return value is compared
            with the best result seen so far.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is saved in; it is
            created when it does not exist, default is `./`.
        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name, default is `acc`.

    Raises:
        ValueError: if `interval` is smaller than 1.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok avoids the check-then-create race when several processes
        # start with the same checkpoint directory.
        os.makedirs(ckpt_directory, exist_ok=True)
        # The attribute keeps its historical (misspelled) name so that code
        # reading it from outside the class keeps working.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Make `file_name` writable and delete it; failures are logged as warnings, not raised."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end: evaluate on schedule and keep track of the best result."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # Evaluate at eval_start_epoch and then every `interval` epochs after it.
        if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            return
        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        # Ties count as an improvement, so the most recent of equal results wins.
        if res < self.best_res:
            return
        self.best_res = res
        self.best_epoch = cur_epoch
        print("update best result: {}".format(res), flush=True)
        if self.save_best_ckpt:
            # Drop the previous best checkpoint before writing the new one.
            if os.path.exists(self.bast_ckpt_path):
                self.remove_ckpoint_file(self.bast_ckpt_path)
            save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
            print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best result and the epoch it was reached in."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
| @@ -15,6 +15,7 @@ | |||
| """crnn training""" | |||
| import os | |||
| import argparse | |||
| import ast | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.common import set_seed | |||
| @@ -28,7 +29,8 @@ from src.loss import CTCLoss | |||
| from src.dataset import create_dataset | |||
| from src.crnn import crnn | |||
| from src.crnn_for_train import TrainOneStepCellWithGradClip | |||
| from src.metric import CRNNAccuracy | |||
| from src.eval_callback import EvalCallBack | |||
| set_seed(1) | |||
| parser = argparse.ArgumentParser(description="crnn training") | |||
| @@ -38,6 +40,16 @@ parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend'] | |||
| help='Running platform, only support Ascend now. Default is Ascend.') | |||
| parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase") | |||
| parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k']) | |||
| parser.add_argument('--eval_dataset', type=str, default='svt', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k']) | |||
| parser.add_argument('--eval_dataset_path', type=str, default=None, help='Dataset path, default is None') | |||
| parser.add_argument("--run_eval", type=ast.literal_eval, default=False, | |||
| help="Run evaluation when training, default is False.") | |||
| parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True, | |||
| help="Save best checkpoint when run_eval is True, default is True.") | |||
| parser.add_argument("--eval_start_epoch", type=int, default=5, | |||
| help="Evaluation start epoch when run_eval is True, default is 5.") | |||
| parser.add_argument("--eval_interval", type=int, default=5, | |||
| help="Evaluation interval when run_eval is True, default is 5.") | |||
| parser.set_defaults(run_distribute=False) | |||
| args_opt = parser.parse_args() | |||
| @@ -50,6 +62,12 @@ if args_opt.platform == 'Ascend': | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(device_id=device_id) | |||
def apply_eval(eval_param):
    """Evaluate `eval_param["model"]` on `eval_param["dataset"]` and return the requested metric.

    Args:
        eval_param (dict): must provide the keys "model", "dataset" and "metrics_name".

    Returns:
        The entry of the evaluation result stored under `eval_param["metrics_name"]`.
    """
    # Read all three entries up front so a missing key fails before evaluation starts.
    runner, data, key = (eval_param[name] for name in ("model", "dataset", "metrics_name"))
    return runner.eval(data)[key]
| if __name__ == '__main__': | |||
| lr_scale = 1 | |||
| @@ -86,16 +104,31 @@ if __name__ == '__main__': | |||
| net = crnn(config) | |||
| opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov) | |||
| net = WithLossCell(net, loss) | |||
| net = TrainOneStepCellWithGradClip(net, opt).set_train() | |||
| net_with_loss = WithLossCell(net, loss) | |||
| net_with_grads = TrainOneStepCellWithGradClip(net_with_loss, opt).set_train() | |||
| # define model | |||
| model = Model(net) | |||
| model = Model(net_with_grads) | |||
| # define callbacks | |||
| callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] | |||
| save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') | |||
| if args_opt.run_eval: | |||
| if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)): | |||
| raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path)) | |||
| eval_dataset = create_dataset(name=args_opt.eval_dataset, | |||
| dataset_path=args_opt.eval_dataset_path, | |||
| batch_size=config.batch_size, | |||
| is_training=False, | |||
| config=config) | |||
| eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)}) | |||
| eval_param_dict = {"model": eval_model, "dataset": eval_dataset, "metrics_name": "CRNNAccuracy"} | |||
| eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval, | |||
| eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True, | |||
| ckpt_directory=save_ckpt_path, besk_ckpt_name="best_acc.ckpt", | |||
| metrics_name="acc") | |||
| callbacks += [eval_cb] | |||
| if config.save_checkpoint and rank == 0: | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, | |||
| keep_checkpoint_max=config.keep_checkpoint_max) | |||
| save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') | |||
| ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck) | |||
| callbacks.append(ckpt_cb) | |||
| model.train(config.epoch_size, dataset, callbacks=callbacks) | |||
| @@ -96,6 +96,8 @@ Here we used 6 datasets for training, and 1 datasets for Evaluation. | |||
| │ ├── create_dataset.py # create mindrecord dataset | |||
| │ ├── ctpn.py # ctpn network definition | |||
| │ ├── dataset.py # data preprocessing | |||
| │ ├── eval_callback.py # evaluation callback while training | |||
| │ ├── eval_utils.py # evaluation function | |||
| │ ├── lr_schedule.py # learning rate scheduler | |||
| │ ├── network_define.py # network definition | |||
| │ └── text_connector | |||
| @@ -235,6 +237,10 @@ Then you can run the scripts/eval_res.sh to calculate the evalulation result. | |||
| bash eval_res.sh | |||
| ``` | |||
| ### Evaluation while training | |||
| To run evaluation while training, add `run_eval` to the start shell script and set it to True. When `run_eval` is True, you can also set the arguments `eval_dataset_path`, `eval_image_path`, `save_best_ckpt`, `eval_start_epoch` and `eval_interval`. | |||
| ### Result | |||
| Evaluation result will be stored in the example path, you can find result like the followings in `log`. | |||
| @@ -14,17 +14,14 @@ | |||
| # ============================================================================ | |||
| """Evaluation for CTPN""" | |||
| import os | |||
| import argparse | |||
| import time | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.common import set_seed | |||
| from src.ctpn import CTPN | |||
| from src.config import config | |||
| from src.dataset import create_ctpn_dataset | |||
| from src.text_connector.detector import detect | |||
| from src.eval_utils import eval_for_ctpn | |||
| set_seed(1) | |||
| parser = argparse.ArgumentParser(description="CTPN evaluation") | |||
| @@ -39,80 +36,13 @@ def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''): | |||
| """ctpn infer.""" | |||
| print("ckpt path is {}".format(ckpt_path)) | |||
| ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False) | |||
| config.batch_size = config.test_batch_size | |||
| total = ds.get_dataset_size() | |||
| print("*************total dataset size is {}".format(total)) | |||
| net = CTPN(config, is_training=False) | |||
| print("eval dataset size is {}".format(total)) | |||
| net = CTPN(config, batch_size=config.test_batch_size, is_training=False) | |||
| param_dict = load_checkpoint(ckpt_path) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| eval_iter = 0 | |||
| print("\n========================================\n") | |||
| print("Processing, please wait a moment.") | |||
| img_basenames = [] | |||
| output_dir = os.path.join(os.getcwd(), "submit") | |||
| if not os.path.exists(output_dir): | |||
| os.mkdir(output_dir) | |||
| for file in os.listdir(img_dir): | |||
| img_basenames.append(os.path.basename(file)) | |||
| for data in ds.create_dict_iterator(): | |||
| img_data = data['image'] | |||
| img_metas = data['image_shape'] | |||
| gt_bboxes = data['box'] | |||
| gt_labels = data['label'] | |||
| gt_num = data['valid_num'] | |||
| start = time.time() | |||
| # run net | |||
| output = net(img_data, gt_bboxes, gt_labels, gt_num) | |||
| gt_bboxes = gt_bboxes.asnumpy() | |||
| gt_labels = gt_labels.asnumpy() | |||
| gt_num = gt_num.asnumpy().astype(bool) | |||
| end = time.time() | |||
| proposal = output[0] | |||
| proposal_mask = output[1] | |||
| print("start to draw pic") | |||
| for j in range(config.test_batch_size): | |||
| img = img_basenames[config.test_batch_size * eval_iter + j] | |||
| all_box_tmp = proposal[j].asnumpy() | |||
| all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1) | |||
| using_boxes_mask = all_box_tmp * all_mask_tmp | |||
| textsegs = using_boxes_mask[:, 0:4].astype(np.float32) | |||
| scores = using_boxes_mask[:, 4].astype(np.float32) | |||
| shape = img_metas.asnumpy()[0][:2].astype(np.int32) | |||
| bboxes = detect(textsegs, scores[:, np.newaxis], shape) | |||
| from PIL import Image, ImageDraw | |||
| im = Image.open(img_dir + '/' + img) | |||
| draw = ImageDraw.Draw(im) | |||
| image_h = img_metas.asnumpy()[j][2] | |||
| image_w = img_metas.asnumpy()[j][3] | |||
| gt_boxs = gt_bboxes[j][gt_num[j], :] | |||
| for gt_box in gt_boxs: | |||
| gt_x1 = gt_box[0] / image_w | |||
| gt_y1 = gt_box[1] / image_h | |||
| gt_x2 = gt_box[2] / image_w | |||
| gt_y2 = gt_box[3] / image_h | |||
| draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\ | |||
| fill='green', width=2) | |||
| file_name = "res_" + img.replace("jpg", "txt") | |||
| output_file = os.path.join(output_dir, file_name) | |||
| f = open(output_file, 'w') | |||
| for bbox in bboxes: | |||
| x1 = bbox[0] / image_w | |||
| y1 = bbox[1] / image_h | |||
| x2 = bbox[2] / image_w | |||
| y2 = bbox[3] / image_h | |||
| draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2) | |||
| str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2)) | |||
| f.write(str_tmp) | |||
| f.write("\n") | |||
| f.close() | |||
| im.save(img) | |||
| percent = round(eval_iter / total * 100, 2) | |||
| eval_iter = eval_iter + 1 | |||
| print("Iter {} cost time {}".format(eval_iter, end - start)) | |||
| print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r') | |||
| eval_for_ctpn(net, ds, img_dir) | |||
| if __name__ == '__main__': | |||
| ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path) | |||
| @@ -36,7 +36,7 @@ if args.device_target == "Ascend": | |||
| context.set_context(device_id=args.device_id) | |||
| if __name__ == '__main__': | |||
| net = CTPN_Infer(config=config) | |||
| net = CTPN_Infer(config=config, batch_size=config.test_batch_size) | |||
| param_dict = load_checkpoint(args.ckpt_file) | |||
| @@ -56,6 +56,8 @@ do | |||
| export RANK_ID=$i | |||
| rm -rf ./train_parallel$i | |||
| mkdir ./train_parallel$i | |||
| cp ./*.py ./train_parallel$i | |||
| cp ./*.zip ./train_parallel$i | |||
| cp ../*.py ./train_parallel$i | |||
| cp *.sh ./train_parallel$i | |||
| cp -r ../src ./train_parallel$i | |||
| @@ -29,8 +29,7 @@ finetune_config = EasyDict({ | |||
| "total_epoch": 50, | |||
| }) | |||
| # use for low case number | |||
| config = EasyDict({ | |||
| config_default = EasyDict({ | |||
| "img_width": 960, | |||
| "img_height": 576, | |||
| "keep_ratio": False, | |||
| @@ -39,7 +38,6 @@ config = EasyDict({ | |||
| "expand_ratio": 1.0, | |||
| # anchor | |||
| "feature_shapes": (36, 60), | |||
| "num_anchors": 14, | |||
| "anchor_base": 16, | |||
| "anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406], | |||
| @@ -56,7 +54,6 @@ config = EasyDict({ | |||
| "neg_iou_thr": 0.5, | |||
| "pos_iou_thr": 0.7, | |||
| "min_pos_iou": 0.001, | |||
| "num_bboxes": 30240, | |||
| "num_gts": 256, | |||
| "num_expected_neg": 512, | |||
| "num_expected_pos": 256, | |||
| @@ -75,12 +72,11 @@ config = EasyDict({ | |||
| # rnn structure | |||
| "input_size": 512, | |||
| "num_step": 60, | |||
| "rnn_batch_size": 36, | |||
| "hidden_size": 128, | |||
| # training | |||
| "warmup_mode": "linear", | |||
| # batch_size only support 1 | |||
| "batch_size": 1, | |||
| "momentum": 0.9, | |||
| "save_checkpoint": True, | |||
| @@ -131,3 +127,12 @@ config = EasyDict({ | |||
| "pretraining_dataset_file": "", | |||
| "finetune_dataset_file": "" | |||
| }) | |||
| config_add = { | |||
| "feature_shapes": (config_default["img_height"] // 16, config_default["img_width"] // 16), | |||
| "num_bboxes": (config_default["img_height"] // 16) * \ | |||
| (config_default["img_width"] // 16) *config_default["num_anchors"], | |||
| "num_step": config_default["img_width"] // 16, | |||
| "rnn_batch_size": config_default["img_height"] // 16 | |||
| } | |||
| config = EasyDict({**config_default, **config_add}) | |||
| @@ -145,7 +145,7 @@ def create_train_dataset(dataset_type): | |||
| # test: icdar2013 test | |||
| icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\ | |||
| config.icdar13_test_path[1], "") | |||
| image_files = icdar_test_image_files | |||
| image_files = sorted(icdar_test_image_files) | |||
| image_anno_dict = icdar_test_anno_dict | |||
| data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \ | |||
| prefix="ctpn_test.mindrecord", file_num=1) | |||
| @@ -29,16 +29,13 @@ class BiLSTM(nn.Cell): | |||
| Define a BiLSTM network which contains two LSTM layers | |||
| Args: | |||
| input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for | |||
| captcha images. | |||
| batch_size(int): batch size of input data, default is 64 | |||
| hidden_size(int): the hidden size in LSTM layers, default is 512 | |||
| config(EasyDict): config for ctpn network | |||
| batch_size(int): batch size of input data, only support 1 | |||
| """ | |||
| def __init__(self, config, is_training=True): | |||
| def __init__(self, config, batch_size): | |||
| super(BiLSTM, self).__init__() | |||
| self.is_training = is_training | |||
| self.batch_size = config.batch_size * config.rnn_batch_size | |||
| print("batch size is {} ".format(self.batch_size)) | |||
| self.batch_size = batch_size | |||
| self.batch_size = self.batch_size * config.rnn_batch_size | |||
| self.input_size = config.input_size | |||
| self.hidden_size = config.hidden_size | |||
| self.num_step = config.num_step | |||
| @@ -84,25 +81,24 @@ class CTPN(nn.Cell): | |||
| Define CTPN network | |||
| Args: | |||
| input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for | |||
| captcha images. | |||
| batch_size(int): batch size of input data, default is 64 | |||
| hidden_size(int): the hidden size in LSTM layers, default is 512 | |||
| config(EasyDict): config for ctpn network | |||
| batch_size(int): batch size of input data, only support 1 | |||
| is_training(bool): whether training, default is True | |||
| """ | |||
| def __init__(self, config, is_training=True): | |||
| def __init__(self, config, batch_size, is_training=True): | |||
| super(CTPN, self).__init__() | |||
| self.config = config | |||
| self.is_training = is_training | |||
| self.batch_size = batch_size | |||
| self.num_step = config.num_step | |||
| self.input_size = config.input_size | |||
| self.batch_size = config.batch_size | |||
| self.hidden_size = config.hidden_size | |||
| self.vgg16_feature_extractor = VGG16FeatureExtraction() | |||
| self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same') | |||
| self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16) | |||
| self.rnn = BiLSTM(self.config, batch_size=self.batch_size).to_float(mstype.float16) | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| self.cast = P.Cast() | |||
| self.is_training = is_training | |||
| # rpn block | |||
| self.rpn_with_loss = RPN(config, | |||
| @@ -115,7 +111,7 @@ class CTPN(nn.Cell): | |||
| self.featmap_size = config.feature_shapes | |||
| self.anchor_list = self.get_anchors(self.featmap_size) | |||
| self.proposal_generator_test = Proposal(config, | |||
| config.test_batch_size, | |||
| self.batch_size, | |||
| config.activate_num_classes, | |||
| config.use_sigmoid_cls) | |||
| self.proposal_generator_test.set_train_local(config, False) | |||
| @@ -143,9 +139,9 @@ class CTPN(nn.Cell): | |||
| return Tensor(anchors, mstype.float16) | |||
| class CTPN_Infer(nn.Cell): | |||
| def __init__(self, config): | |||
| def __init__(self, config, batch_size): | |||
| super(CTPN_Infer, self).__init__() | |||
| self.network = CTPN(config, is_training=False) | |||
| self.network = CTPN(config, batch_size=batch_size, is_training=False) | |||
| self.network.set_train(False) | |||
| def construct(self, img_data): | |||
| @@ -289,11 +289,11 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num= | |||
| input_columns=["image", "annotation"], | |||
| output_columns=["image", "box", "label", "valid_num", "image_shape"], | |||
| column_order=["image", "box", "label", "valid_num", "image_shape"], | |||
| num_parallel_workers=num_parallel_workers, | |||
| num_parallel_workers=8, | |||
| python_multiprocessing=True) | |||
| ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"], | |||
| num_parallel_workers=24) | |||
| num_parallel_workers=8) | |||
| # transpose_column from python to c | |||
| ds = ds.map(operations=[type_cast1], input_columns=["image_shape"]) | |||
| ds = ds.map(operations=[type_cast1], input_columns=["box"]) | |||
| @@ -0,0 +1,91 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation callback when training""" | |||
| import os | |||
| import stat | |||
| from mindspore import save_checkpoint | |||
| from mindspore import log as logger | |||
| from mindspore.train.callback import Callback | |||
class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    Args:
        eval_function (function): evaluation function. It is called as
            `eval_function(eval_param_dict)` and its return value is compared
            with the best result seen so far.
        eval_param_dict (dict): evaluation parameters' configure dict.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): Whether to save best checkpoint, default is True.
        ckpt_directory (str): directory the best checkpoint is saved in; it is
            created when it does not exist, default is `./`.
        besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name, default is `acc`.

    Raises:
        ValueError: if `interval` is smaller than 1.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
                 ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = eval_function
        self.eval_start_epoch = eval_start_epoch
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0
        # exist_ok avoids the check-then-create race when several processes
        # start with the same checkpoint directory.
        os.makedirs(ckpt_directory, exist_ok=True)
        # The attribute keeps its historical (misspelled) name so that code
        # reading it from outside the class keeps working.
        self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Make `file_name` writable and delete it; failures are logged as warnings, not raised."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def epoch_end(self, run_context):
        """Callback when epoch end: evaluate on schedule and keep track of the best result."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # Evaluate at eval_start_epoch and then every `interval` epochs after it.
        if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            return
        res = self.eval_function(self.eval_param_dict)
        print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
        # Ties count as an improvement, so the most recent of equal results wins.
        if res < self.best_res:
            return
        self.best_res = res
        self.best_epoch = cur_epoch
        print("update best result: {}".format(res), flush=True)
        if self.save_best_ckpt:
            # Drop the previous best checkpoint before writing the new one.
            if os.path.exists(self.bast_ckpt_path):
                self.remove_ckpoint_file(self.bast_ckpt_path)
            save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
            print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)

    def end(self, run_context):
        """Callback when training ends: print the best result and the epoch it was reached in."""
        print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
                                                                                     self.best_res,
                                                                                     self.best_epoch), flush=True)
| @@ -0,0 +1,96 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation utils for CTPN""" | |||
| import os | |||
| import subprocess | |||
| import numpy as np | |||
| from src.config import config | |||
| from src.text_connector.detector import detect | |||
def exec_shell_cmd(cmd):
    """Run `cmd` through the shell and return its standard output, stripped of surrounding whitespace.

    Args:
        cmd: the command line; it is converted to a string and interpreted by the shell,
            so it must come from trusted code, never from external input.

    Returns:
        str, the captured standard output with leading/trailing whitespace removed.

    Raises:
        ValueError: if the command exits with a non-zero status. The message carries the
            exit code and the captured standard error so the failure can be diagnosed.
    """
    completed = subprocess.run("{}".format(cmd), shell=True, stdin=subprocess.DEVNULL,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                               universal_newlines=True, check=False)
    if completed.returncode != 0:
        raise ValueError("{} is not a executable command, please check. (exit code {}: {})".format(
            cmd, completed.returncode, completed.stderr.strip()))
    return completed.stdout.strip()
def get_eval_result():
    """Zip the result files, run `eval_res.sh` and return the hmean value found in `log`.

    Returns:
        float, the last `hmean` value that the grep/awk pipeline extracts from the file `log`.
    """
    # Rebuild submit.zip from ./submit/*.txt, then launch the scoring script.
    os.system('rm -rf submit*.zip;cd ./submit/;zip -r ../submit.zip *.txt;cd ../;bash eval_res.sh')
    # Only the last matching line of `log` is kept (tail -n 1).
    hmean_text = exec_shell_cmd("grep hmean log | awk '{print $NF}' | awk -F} '{print $1}' |tail -n 1")
    return float(hmean_text)
def eval_for_ctpn(network, dataset, eval_image_path):
    """Run `network` over `dataset` and write one detection-result file per image.

    For every image a text file `res_<image name with "jpg" replaced by "txt">` is written
    into `<cwd>/submit`, one `x1,y1,x2,y2` line per detected box. The ground-truth boxes
    (green) and detected boxes (red) are drawn on the image, which is then saved under its
    base name in the current working directory.

    Images are matched to dataset samples by position: the sorted file names of
    `eval_image_path` are indexed with the running sample index, so the dataset is
    presumably expected to be in the same sorted order (verify against dataset creation).

    Args:
        network: the network to evaluate; it is switched to inference mode with `set_train(False)`.
        dataset: dataset whose dict iterator yields the columns 'image', 'image_shape',
            'box', 'label' and 'valid_num'.
        eval_image_path (str): directory that contains the evaluation images.
    """
    # Imported here, as in the original, so importing this module does not require Pillow.
    from PIL import Image, ImageDraw

    network.set_train(False)
    output_dir = os.path.join(os.getcwd(), "submit")
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    img_basenames = sorted(os.path.basename(file) for file in os.listdir(eval_image_path))

    for eval_iter, data in enumerate(dataset.create_dict_iterator()):
        img_data = data['image']
        img_metas = data['image_shape']
        gt_bboxes = data['box']
        gt_labels = data['label']
        gt_num = data['valid_num']

        # run net
        output = network(img_data, gt_bboxes, gt_labels, gt_num)
        proposal = output[0]
        proposal_mask = output[1]

        # Convert once per batch instead of once per use.
        gt_bboxes = gt_bboxes.asnumpy()
        gt_num = gt_num.asnumpy().astype(bool)
        img_metas_np = img_metas.asnumpy()

        for j in range(config.test_batch_size):
            img = img_basenames[config.test_batch_size * eval_iter + j]
            all_box_tmp = proposal[j].asnumpy()
            all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
            # Zero out the proposals that the mask marks as invalid.
            using_boxes_mask = all_box_tmp * all_mask_tmp
            textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
            scores = using_boxes_mask[:, 4].astype(np.float32)
            # NOTE(review): this reads sample 0 rather than sample j; harmless while
            # test_batch_size is 1, confirm before evaluating with larger batches.
            shape = img_metas_np[0][:2].astype(np.int32)
            bboxes = detect(textsegs, scores[:, np.newaxis], shape)

            im = Image.open(eval_image_path + '/' + img)
            draw = ImageDraw.Draw(im)
            # Entries 2 and 3 of image_shape are used as divisors for the box coordinates,
            # presumably the resize ratios of height and width (TODO confirm).
            image_h = img_metas_np[j][2]
            image_w = img_metas_np[j][3]

            for gt_box in gt_bboxes[j][gt_num[j], :]:
                gt_x1 = gt_box[0] / image_w
                gt_y1 = gt_box[1] / image_h
                gt_x2 = gt_box[2] / image_w
                gt_y2 = gt_box[3] / image_h
                draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],
                          fill='green', width=2)

            file_name = "res_" + img.replace("jpg", "txt")
            output_file = os.path.join(output_dir, file_name)
            # The context manager closes the file even if drawing or formatting raises.
            with open(output_file, 'w') as f:
                for bbox in bboxes:
                    x1 = bbox[0] / image_w
                    y1 = bbox[1] / image_h
                    x2 = bbox[2] / image_w
                    y2 = bbox[3] / image_h
                    draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
                    str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
                    f.write(str_tmp)
                    f.write("\n")
            im.save(img)
| @@ -32,6 +32,8 @@ from src.config import config, pretrain_config, finetune_config | |||
| from src.dataset import create_ctpn_dataset | |||
| from src.lr_schedule import dynamic_lr | |||
| from src.network_define import LossCallBack, LossNet, WithLossCell, TrainOneStepCell | |||
| from src.eval_utils import eval_for_ctpn, get_eval_result | |||
| from src.eval_callback import EvalCallBack | |||
| set_seed(1) | |||
| @@ -43,10 +45,30 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums, | |||
# Distributed-run and task-selection options.
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
parser.add_argument(
    "--task_type",
    type=str,
    default="Pretraining",
    choices=['Pretraining', 'Finetune'],
    help="task type, default:Pretraining",
)

# Evaluation-while-training options. Boolean flags are parsed with
# ast.literal_eval, so they must be given as Python literals ("True"/"False").
parser.add_argument(
    "--run_eval",
    type=ast.literal_eval,
    default=False,
    help="Run evaluation when training, default is False.",
)
parser.add_argument(
    "--save_best_ckpt",
    type=ast.literal_eval,
    default=True,
    help="Save best checkpoint when run_eval is True, default is True.",
)
parser.add_argument(
    "--eval_image_path",
    type=str,
    default="",
    help="eval image path, when run_eval is True, eval_image_path should be set.",
)
parser.add_argument(
    "--eval_dataset_path",
    type=str,
    default="",
    help="eval dataset path, when run_eval is True, eval_dataset_path should be set.",
)
parser.add_argument(
    "--eval_start_epoch",
    type=int,
    default=10,
    help="Evaluation start epoch when run_eval is True, default is 10.",
)
parser.add_argument(
    "--eval_interval",
    type=int,
    default=10,
    help="Evaluation interval when run_eval is True, default is 10.",
)

args_opt = parser.parse_args()

# Graph mode on Ascend, bound to the device id taken from the command line.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="Ascend",
    device_id=args_opt.device_id,
    save_graphs=True,
)
def apply_eval(eval_param):
    """Run one CTPN evaluation pass and return its score.

    Args:
        eval_param (dict): must contain the keys "eval_network",
            "eval_dataset" and "eval_image_path"; the three values are
            forwarded, in that order, to ``eval_for_ctpn``.

    Returns:
        Whatever ``get_eval_result()`` returns after the evaluation pass
        (presumably the hmean metric, going by how the callback labels it
        -- confirm against ``src.eval_utils``).

    Raises:
        KeyError: if any of the three required keys is missing.
    """
    eval_for_ctpn(
        eval_param["eval_network"],
        eval_param["eval_dataset"],
        eval_param["eval_image_path"],
    )
    return get_eval_result()
| if __name__ == '__main__': | |||
| if args_opt.run_distribute: | |||
| rank = args_opt.rank_id | |||
| @@ -78,7 +100,7 @@ if __name__ == '__main__': | |||
| dataset = create_ctpn_dataset(mindrecord_file, repeat_num=1,\ | |||
| batch_size=config.batch_size, device_num=device_num, rank_id=rank) | |||
| dataset_size = dataset.get_dataset_size() | |||
| net = CTPN(config=config, is_training=True) | |||
| net = CTPN(config=config, batch_size=config.batch_size) | |||
| net = net.set_train() | |||
| load_path = args_opt.pre_trained | |||
| @@ -100,20 +122,34 @@ if __name__ == '__main__': | |||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| if args_opt.run_distribute: | |||
| net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True, | |||
| mean=True, degree=device_num) | |||
| net_with_grads = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True, \ | |||
| mean=True, degree=device_num) | |||
| else: | |||
| net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale) | |||
| net_with_grads = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale) | |||
| time_cb = TimeMonitor(data_size=dataset_size) | |||
| loss_cb = LossCallBack(rank_id=rank) | |||
| cb = [time_cb, loss_cb] | |||
| save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/") | |||
| if config.save_checkpoint: | |||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*dataset_size, | |||
| keep_checkpoint_max=config.keep_checkpoint_max) | |||
| save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/") | |||
| ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path, config=ckptconfig) | |||
| cb += [ckpoint_cb] | |||
| model = Model(net) | |||
if args_opt.run_eval:
    # Evaluation while training: check both paths before building anything so
    # a bad path fails here instead of at the first evaluation epoch. The
    # falsy test rejects None as well as the empty-string argparse default.
    if not args_opt.eval_dataset_path or not os.path.isfile(args_opt.eval_dataset_path):
        raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path))
    if not args_opt.eval_image_path or not os.path.isdir(args_opt.eval_image_path):
        raise ValueError("{} is not a existing path.".format(args_opt.eval_image_path))

    eval_dataset = create_ctpn_dataset(
        args_opt.eval_dataset_path,
        batch_size=config.batch_size,
        repeat_num=1,
        is_training=False,
    )
    # The callback evaluates the same network object that is being trained.
    eval_net = net
    eval_param_dict = {
        "eval_network": eval_net,
        "eval_dataset": eval_dataset,
        "eval_image_path": args_opt.eval_image_path,
    }
    # save_best_ckpt now follows the --save_best_ckpt flag; it was previously
    # hard-coded to True, which made the command-line option a no-op.
    # NOTE(review): "besk_ckpt_name" is spelled as in the original call and
    # must match the EvalCallBack signature -- do not "correct" it here.
    eval_cb = EvalCallBack(
        apply_eval,
        eval_param_dict,
        interval=args_opt.eval_interval,
        eval_start_epoch=args_opt.eval_start_epoch,
        save_best_ckpt=args_opt.save_best_ckpt,
        ckpt_directory=save_checkpoint_path,
        besk_ckpt_name="best_acc.ckpt",
        metrics_name="hmean",
    )
    cb += [eval_cb]
| model = Model(net_with_grads) | |||
| model.train(training_cfg.total_epoch, dataset, callbacks=cb, dataset_sink_mode=True) | |||