
!7010 psenet 8p accuracy improve.

Merge pull request !7010 from linqingke/psenet
tags/v1.1.0 · committed by mindspore-ci-bot · parent commit afb1a91568
5 changed files with 50 additions and 12 deletions
  1. model_zoo/official/cv/psenet/scripts/run_distribute_train.sh (+3, -3)
  2. model_zoo/official/cv/psenet/src/config.py (+7, -2)
  3. model_zoo/official/cv/psenet/src/generate_hccn_file.py (+1, -1)
  4. model_zoo/official/cv/psenet/src/lr_schedule.py (+37, -0)
  5. model_zoo/official/cv/psenet/train.py (+2, -6)

model_zoo/official/cv/psenet/scripts/run_distribute_train.sh (+3, -3)

@@ -41,9 +41,9 @@ fi

 python ${current_exec_path}/src/generate_hccn_file.py

-export DEVICE_NUM=4
-export RANK_SIZE=4
-export RANK_TABLE_FILE=${current_exec_path}/rank_table_4p.json
+export DEVICE_NUM=8
+export RANK_SIZE=8
+export RANK_TABLE_FILE=${current_exec_path}/rank_table_8p.json

 for((i=0; i<${DEVICE_NUM}; i++))
 do
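
Side note: these exported variables are the usual way an Ascend launch script hands the device topology to each worker process. A minimal sketch of how a worker might read them (the variable names come from the script above; the Python itself is illustrative and not part of this PR):

import os

# Set by run_distribute_train.sh; DEVICE_ID is typically exported per
# process inside the launch loop that follows this hunk.
rank_size = int(os.environ.get("RANK_SIZE", "1"))      # 8 after this change
device_id = int(os.environ.get("DEVICE_ID", "0"))
rank_table = os.environ.get("RANK_TABLE_FILE", "")     # .../rank_table_8p.json

print(f"device {device_id} of {rank_size}, rank table: {rank_table}")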


model_zoo/official/cv/psenet/src/config.py (+7, -2)

@@ -29,6 +29,12 @@ config = ed({
     # neck
     'NECK_OUT_CHANNEL': 256,
+    # lr
+    "BASE_LR": 2e-3,
+    "TRAIN_TOTAL_ITER": 58000,
+    "WARMUP_STEP": 620,
+    "WARMUP_RATIO": 1/3,
     # dataset for train
     "TRAIN_ROOT_DIR": 'psenet/ic15/',
     "TRAIN_IS_TRANSFORM": True,
@@ -37,9 +43,8 @@ config = ed({
     "TRAIN_MIN_SCALE": 0.4,
     "TRAIN_BUFFER_SIZE": 8,
     "TRAIN_BATCH_SIZE": 4,
-    "TRAIN_REPEAT_NUM": 608*4,
+    "TRAIN_REPEAT_NUM": 1800,
     "TRAIN_DROP_REMAINDER": True,
-    "TRAIN_TOTAL_ITER": 152000,
     "TRAIN_MODEL_SAVE_PATH": './checkpoints/',
     # dataset for test
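
A quick arithmetic check on the new constants (plain Python, only illustrating how the numbers relate; none of this is in the PR):

BASE_LR = 2e-3
TRAIN_TOTAL_ITER = 58000
WARMUP_STEP = 620
WARMUP_RATIO = 1 / 3

print(WARMUP_STEP / TRAIN_TOTAL_ITER)  # ~0.011: warmup covers about 1% of training
print(BASE_LR * WARMUP_RATIO)          # ~6.7e-4: the lr the warmup ramp starts from

Note the budget change: the old config trained for 152000 iterations at a lower base lr (1e-3 in the old train.py below), while the new schedule uses 58000 iterations with a 2e-3 peak, which per the PR title is what improves 8p accuracy.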


model_zoo/official/cv/psenet/src/generate_hccn_file.py (+1, -1)

@@ -17,7 +17,7 @@
 import os
 import socket

-RANK_TABLE_SAVE_PATH = './rank_table_4p.json'
+RANK_TABLE_SAVE_PATH = './rank_table_8p.json'

 def main():


model_zoo/official/cv/psenet/src/lr_schedule.py (+37, -0)

@@ -0,0 +1,37 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""lr generator for psenet"""
+import math
+
+def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    learning_rate = float(init_lr) + lr_inc * current_step
+    return learning_rate
+
+def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
+    base = float(current_step - warmup_steps) / float(decay_steps)
+    learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
+    return learning_rate
+
+def dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1/3):
+    """dynamic learning rate generator"""
+    lr = []
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * warmup_ratio))
+        else:
+            lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
+
+    return lr
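
To see the shape of the schedule this file produces with the values added to config.py above, a small usage sketch (the expected values in the comments are approximate):

from src.lr_schedule import dynamic_lr

lrs = dynamic_lr(base_lr=2e-3, total_steps=58000, warmup_steps=620, warmup_ratio=1/3)

print(lrs[0])      # ~6.7e-4: warmup starts at base_lr * warmup_ratio
print(lrs[619])    # ~2.0e-3: linear warmup reaches base_lr after WARMUP_STEP steps
print(lrs[29000])  # ~1.0e-3: roughly halfway down the cosine
print(lrs[57999])  # ~5.7e-7: near zero at the final step

One detail worth noticing: dynamic_lr passes total_steps (not total_steps - warmup_steps) as decay_steps, so the cosine decays toward zero but never quite reaches it at the last step.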

model_zoo/official/cv/psenet/train.py (+2, -6)

@@ -14,7 +14,6 @@
 # ============================================================================
-import math
 import argparse

 import mindspore.nn as nn
 from mindspore import context
@@ -29,6 +28,7 @@ from src.config import config
 from src.ETSNET.etsnet import ETSNet
 from src.ETSNET.dice_loss import DiceLoss
 from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack
+from src.lr_schedule import dynamic_lr

 parser = argparse.ArgumentParser(description='Hyperparams')
 parser.add_argument('--run_distribute', default=False, action='store_true',
@@ -41,10 +41,6 @@ args = parser.parse_args()
 set_seed(1)
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)

-def lr_generator(start_lr, lr_scale, total_iters):
-    lrs = [start_lr * (lr_scale ** math.floor(cur_iter * 1.0 / (total_iters / 3))) for cur_iter in range(total_iters)]
-    return lrs
-
 def train():
     rank_id = 0
     if args.run_distribute:
@@ -67,7 +63,7 @@ def train():
     criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)

-    lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER)
+    lrs = dynamic_lr(config.BASE_LR, config.TRAIN_TOTAL_ITER, config.WARMUP_STEP, config.WARMUP_RATIO)
     opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)
     # warp model
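
For intuition about the change, here is a hedged side-by-side of the deleted step-decay schedule and the new warmup-plus-cosine one. lr_generator is reproduced from the lines removed above; evaluating both over 58000 steps is an illustrative choice, since the old config actually ran 152000 iterations:

import math
from src.lr_schedule import dynamic_lr

def lr_generator(start_lr, lr_scale, total_iters):
    # The schedule this PR removes: lr drops 10x at 1/3 and 2/3 of training.
    return [start_lr * (lr_scale ** math.floor(i * 1.0 / (total_iters / 3)))
            for i in range(total_iters)]

old = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=58000)
new = dynamic_lr(base_lr=2e-3, total_steps=58000, warmup_steps=620, warmup_ratio=1/3)

for step in (0, 20000, 40000, 57999):
    print(step, old[step], new[step])
# old: hard 1e-3 / 1e-4 / 1e-5 plateaus; new: smooth ramp to 2e-3, then cosine decay.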

