
!3949 support pretrain for maskrcnn

Merge pull request !3949 from meixiaowei/master
tags/v0.7.0-beta
mindspore-ci-bot committed 5 years ago
commit e0d144460c
4 changed files with 11 additions and 8 deletions

  1. model_zoo/official/cv/maskrcnn/README.md (+1, -1)
  2. model_zoo/official/cv/maskrcnn/src/config.py (+1, -0)
  3. model_zoo/official/cv/maskrcnn/src/lr_schedule.py (+3, -3)
  4. model_zoo/official/cv/maskrcnn/train.py (+6, -4)

model_zoo/official/cv/maskrcnn/README.md (+1, -1)

@@ -35,7 +35,7 @@ MaskRcnn is a two-stage target detection network,This network uses a region prop
└─train2017
```
Notice that the coco2017 dataset will be converted to MindRecord which is a data format in MindSpore. The dataset conversion may take about 4 hours.
2. If your own dataset is used. **Select dataset to other when run script.**
Organize the dataset infomation into a TXT file, each row in the file is as follows:



model_zoo/official/cv/maskrcnn/src/config.py (+1, -0)

@@ -134,6 +134,7 @@ config = ed({
     "loss_scale": 1,
     "momentum": 0.91,
     "weight_decay": 1e-4,
+    "pretrain_epoch_size": 0,
     "epoch_size": 12,
     "save_checkpoint": True,
     "save_checkpoint_epochs": 1,


model_zoo/official/cv/maskrcnn/src/lr_schedule.py (+3, -3)

@@ -25,7 +25,7 @@ def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
     learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
     return learning_rate
 
-def dynamic_lr(config, rank_size=1):
+def dynamic_lr(config, rank_size=1, start_steps=0):
     """dynamic learning rate generator"""
     base_lr = config.base_lr
 
@@ -38,5 +38,5 @@ def dynamic_lr(config, rank_size=1):
             lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
         else:
             lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
-    return lr
+    learning_rate = lr[start_steps:]
+    return learning_rate
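The effect of the new `start_steps` argument: `dynamic_lr` still builds the learning-rate list for the full run, then drops the first `start_steps` entries, so a resumed job continues along the same warmup-plus-cosine curve instead of restarting it. A simplified, self-contained stand-in (the warmup and cosine formulas here are illustrative, not the exact helpers from `lr_schedule.py`):

```python
import math

def toy_dynamic_lr(base_lr, total_steps, warmup_steps, start_steps=0):
    """Build the full schedule, then drop the first `start_steps` entries,
    mirroring how dynamic_lr(config, ..., start_steps=...) resumes mid-curve."""
    lr = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr.append(base_lr * (i + 1) / warmup_steps)                   # linear warmup
        else:
            progress = (i - warmup_steps) / max(1, total_steps - warmup_steps)
            lr.append((1 + math.cos(progress * math.pi)) / 2 * base_lr)   # cosine decay
    return lr[start_steps:]

# Resuming after 10 steps yields exactly the tail of the full schedule.
full = toy_dynamic_lr(0.02, total_steps=60, warmup_steps=5)
resumed = toy_dynamic_lr(0.02, total_steps=60, warmup_steps=5, start_steps=10)
assert resumed == full[10:]
```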

model_zoo/official/cv/maskrcnn/train.py (+6, -4)

@@ -108,13 +108,15 @@ if __name__ == '__main__':
     load_path = args_opt.pre_trained
     if load_path != "":
         param_dict = load_checkpoint(load_path)
-        for item in list(param_dict.keys()):
-            if not (item.startswith('backbone') or item.startswith('rcnn_mask')):
-                param_dict.pop(item)
+        if config.pretrain_epoch_size == 0:
+            for item in list(param_dict.keys()):
+                if not (item.startswith('backbone') or item.startswith('rcnn_mask')):
+                    param_dict.pop(item)
         load_param_into_net(net, param_dict)
 
     loss = LossNet()
-    lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32)
+    lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size),
+                mstype.float32)
     opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,
               weight_decay=config.weight_decay, loss_scale=config.loss_scale)
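Two behaviours hang off `pretrain_epoch_size` here: with the default value 0 the checkpoint is treated as a backbone-only pretrained model, so everything except the `backbone` and `rcnn_mask` parameters is dropped before `load_param_into_net`; with a non-zero value the whole MaskRCNN checkpoint is restored and the learning-rate schedule is offset by `pretrain_epoch_size * dataset_size` steps. A hypothetical sketch of that filtering rule, with plain dicts standing in for the MindSpore parameter dict:

```python
# Hypothetical sketch of the checkpoint-filtering branch added in train.py.
def filter_params(param_dict, pretrain_epoch_size):
    if pretrain_epoch_size == 0:
        # Fresh run: keep only the backbone and mask-head weights.
        return {k: v for k, v in param_dict.items()
                if k.startswith('backbone') or k.startswith('rcnn_mask')}
    # Resumed run: keep every parameter, including the other heads.
    return param_dict

ckpt = {'backbone.conv1.weight': 1, 'rcnn_mask.fc.weight': 2, 'rpn.cls.weight': 3}
assert filter_params(ckpt, 0) == {'backbone.conv1.weight': 1, 'rcnn_mask.fc.weight': 2}
assert filter_params(ckpt, 6) == ckpt
```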


