From ee023b2e029f562a1a37be3149c9cc1a3cf114af Mon Sep 17 00:00:00 2001 From: qujianwei Date: Thu, 8 Apr 2021 19:23:34 +0800 Subject: [PATCH] fix ctpn backbone vgg16 init --- model_zoo/official/cv/ctpn/README.md | 14 +++++++- model_zoo/official/cv/ctpn/eval.py | 1 + model_zoo/official/cv/ctpn/src/CTPN/vgg16.py | 34 ++++++------------- .../official/cv/ctpn/src/create_dataset.py | 2 +- 4 files changed, 26 insertions(+), 25 deletions(-) diff --git a/model_zoo/official/cv/ctpn/README.md b/model_zoo/official/cv/ctpn/README.md index 0913cee2fe..19fd87aa1c 100644 --- a/model_zoo/official/cv/ctpn/README.md +++ b/model_zoo/official/cv/ctpn/README.md @@ -169,11 +169,23 @@ The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. T ... from src.vgg16 import VGG16 ... -network = VGG16() +network = VGG16(num_classes=cfg.num_classes) ... ``` +To train a better model, you can modify some parameters in model_zoo/official/cv/vgg16/src/config.py. We suggest setting "warmup_epochs" as shown below; you can also try adjusting other parameters. + +```python + +imagenet_cfg = edict({ + ... + "warmup_epochs": 5 + ... +}) + +``` + Then you can train it with ImageNet2012. > Notes: > RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html) , and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size. 
diff --git a/model_zoo/official/cv/ctpn/eval.py b/model_zoo/official/cv/ctpn/eval.py index 17e3bfa075..6ef57f9992 100644 --- a/model_zoo/official/cv/ctpn/eval.py +++ b/model_zoo/official/cv/ctpn/eval.py @@ -56,6 +56,7 @@ def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''): os.mkdir(output_dir) for file in os.listdir(img_dir): img_basenames.append(os.path.basename(file)) + img_basenames = sorted(img_basenames) for data in ds.create_dict_iterator(): img_data = data['image'] img_metas = data['image_shape'] diff --git a/model_zoo/official/cv/ctpn/src/CTPN/vgg16.py b/model_zoo/official/cv/ctpn/src/CTPN/vgg16.py index 0d06e68b95..8bd9f63849 100644 --- a/model_zoo/official/cv/ctpn/src/CTPN/vgg16.py +++ b/model_zoo/official/cv/ctpn/src/CTPN/vgg16.py @@ -23,36 +23,22 @@ def _weight_variable(shape, factor=0.01): init_value = np.random.randn(*shape).astype(np.float32) * factor return Tensor(init_value) -def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False): - """Batchnorm2D wrapper.""" - gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32)) - beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32)) - moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32)) - moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32)) - - return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init, - beta_init=beta_init, moving_mean_init=moving_mean_init, - moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics) - def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True): """Conv2D wrapper.""" - weights = 'ones' layers = [] conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, - pad_mode=pad_mode, weight_init=weights, has_bias=False) + pad_mode=pad_mode, has_bias=False) if not weights_update: conv.weight.requires_grad = False layers += [conv] - 
layers += [_BatchNorm2dInit(out_channels)] + layers += [nn.BatchNorm2d(out_channels)] return nn.SequentialCell(layers) def _fc(in_channels, out_channels): '''full connection layer''' - weight = _weight_variable((out_channels, in_channels)) - bias = _weight_variable((out_channels,)) - return nn.Dense(in_channels, out_channels, weight, bias) + return nn.Dense(in_channels, out_channels) class VGG16FeatureExtraction(nn.Cell): @@ -141,36 +127,38 @@ class VGG16Classfier(nn.Cell): self.relu = nn.ReLU() self.fc1 = _fc(in_channels=7*7*512, out_channels=4096) self.fc2 = _fc(in_channels=4096, out_channels=4096) - self.batch_size = 32 self.reshape = P.Reshape() + self.dropout = nn.Dropout(0.5) def construct(self, x): """ :param x: shape=(B, 512, 7, 7) :return: """ - x = self.reshape(x, (self.batch_size, 7*7*512)) + x = self.reshape(x, (-1, 7*7*512)) x = self.fc1(x) x = self.relu(x) + x = self.dropout(x) x = self.fc2(x) x = self.relu(x) + x = self.dropout(x) return x class VGG16(nn.Cell): - def __init__(self): + def __init__(self, num_classes): """VGG16 construct for training backbone""" super(VGG16, self).__init__() - self.feature_extraction = VGG16FeatureExtraction(weights_update=True) + self.vgg16_feature_extractor = VGG16FeatureExtraction(weights_update=True) self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) self.classifier = VGG16Classfier() - self.fc3 = _fc(in_channels=4096, out_channels=1000) + self.fc3 = _fc(in_channels=4096, out_channels=num_classes) def construct(self, x): """ :param x: shape=(B, 3, 224, 224) :return: logits, shape=(B, 1000) """ - feature_maps = self.feature_extraction(x) + feature_maps = self.vgg16_feature_extractor(x) x = self.max_pool(feature_maps) x = self.classifier(x) x = self.fc3(x) diff --git a/model_zoo/official/cv/ctpn/src/create_dataset.py b/model_zoo/official/cv/ctpn/src/create_dataset.py index ef9a8faf2c..d0066dd87f 100644 --- a/model_zoo/official/cv/ctpn/src/create_dataset.py +++ b/model_zoo/official/cv/ctpn/src/create_dataset.py @@ 
-145,7 +145,7 @@ def create_train_dataset(dataset_type): # test: icdar2013 test icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\ config.icdar13_test_path[1], "") - image_files = icdar_test_image_files + image_files = sorted(icdar_test_image_files) image_anno_dict = icdar_test_anno_dict data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \ prefix="ctpn_test.mindrecord", file_num=1)