Browse Source

!15129 fix ctpn vgg16 backbone training

From: @qujianwei
Reviewed-by: @oacjiewen,@c_34
Signed-off-by: @c_34
tags/v1.2.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
210a4a2490
4 changed files with 26 additions and 25 deletions
  1. +13
    -1
      model_zoo/official/cv/ctpn/README.md
  2. +1
    -0
      model_zoo/official/cv/ctpn/eval.py
  3. +11
    -23
      model_zoo/official/cv/ctpn/src/CTPN/vgg16.py
  4. +1
    -1
      model_zoo/official/cv/ctpn/src/create_dataset.py

+ 13
- 1
model_zoo/official/cv/ctpn/README.md View File

@@ -169,11 +169,23 @@ The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. T
...
from src.vgg16 import VGG16
...
network = VGG16()
network = VGG16(num_classes=cfg.num_classes)
...

```

To train a better model, you can modify some parameters in model_zoo/official/cv/vgg16/src/config.py. We suggest changing "warmup_epochs" as shown below; you can also try adjusting other parameters.

```python

imagenet_cfg = edict({
...
"warmup_epochs": 5
...
})

```

Then you can train it with ImageNet2012.
> Notes:
> For RANK_TABLE_FILE, refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as described in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compiling time increases with the growth of model size.


+ 1
- 0
model_zoo/official/cv/ctpn/eval.py View File

@@ -56,6 +56,7 @@ def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
os.mkdir(output_dir)
for file in os.listdir(img_dir):
img_basenames.append(os.path.basename(file))
img_basenames = sorted(img_basenames)
for data in ds.create_dict_iterator():
img_data = data['image']
img_metas = data['image_shape']


+ 11
- 23
model_zoo/official/cv/ctpn/src/CTPN/vgg16.py View File

@@ -23,36 +23,22 @@ def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)

def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False):
"""Batchnorm2D wrapper."""
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))

return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
beta_init=beta_init, moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics)

def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True):
"""Conv2D wrapper."""
weights = 'ones'
layers = []
conv = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
pad_mode=pad_mode, weight_init=weights, has_bias=False)
pad_mode=pad_mode, has_bias=False)
if not weights_update:
conv.weight.requires_grad = False
layers += [conv]
layers += [_BatchNorm2dInit(out_channels)]
layers += [nn.BatchNorm2d(out_channels)]
return nn.SequentialCell(layers)


def _fc(in_channels, out_channels):
'''full connection layer'''
weight = _weight_variable((out_channels, in_channels))
bias = _weight_variable((out_channels,))
return nn.Dense(in_channels, out_channels, weight, bias)
return nn.Dense(in_channels, out_channels)


class VGG16FeatureExtraction(nn.Cell):
@@ -141,36 +127,38 @@ class VGG16Classfier(nn.Cell):
self.relu = nn.ReLU()
self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
self.fc2 = _fc(in_channels=4096, out_channels=4096)
self.batch_size = 32
self.reshape = P.Reshape()
self.dropout = nn.Dropout(0.5)

def construct(self, x):
"""
:param x: shape=(B, 512, 7, 7)
:return:
"""
x = self.reshape(x, (self.batch_size, 7*7*512))
x = self.reshape(x, (-1, 7*7*512))
x = self.fc1(x)
x = self.relu(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.relu(x)
x = self.dropout(x)
return x

class VGG16(nn.Cell):
def __init__(self):
def __init__(self, num_classes):
"""VGG16 construct for training backbone"""
super(VGG16, self).__init__()
self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
self.vgg16_feature_extractor = VGG16FeatureExtraction(weights_update=True)
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.classifier = VGG16Classfier()
self.fc3 = _fc(in_channels=4096, out_channels=1000)
self.fc3 = _fc(in_channels=4096, out_channels=num_classes)

def construct(self, x):
"""
:param x: shape=(B, 3, 224, 224)
:return: logits, shape=(B, 1000)
"""
feature_maps = self.feature_extraction(x)
feature_maps = self.vgg16_feature_extractor(x)
x = self.max_pool(feature_maps)
x = self.classifier(x)
x = self.fc3(x)


+ 1
- 1
model_zoo/official/cv/ctpn/src/create_dataset.py View File

@@ -145,7 +145,7 @@ def create_train_dataset(dataset_type):
# test: icdar2013 test
icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\
config.icdar13_test_path[1], "")
image_files = icdar_test_image_files
image_files = sorted(icdar_test_image_files)
image_anno_dict = icdar_test_anno_dict
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \
prefix="ctpn_test.mindrecord", file_num=1)


Loading…
Cancel
Save