Merge pull request !7010 from linqingke/psenet (tag: v1.1.0)
| @@ -41,9 +41,9 @@ fi | |||||
| python ${current_exec_path}/src/generate_hccn_file.py | python ${current_exec_path}/src/generate_hccn_file.py | ||||
| export DEVICE_NUM=4 | |||||
| export RANK_SIZE=4 | |||||
| export RANK_TABLE_FILE=${current_exec_path}/rank_table_4p.json | |||||
| export DEVICE_NUM=8 | |||||
| export RANK_SIZE=8 | |||||
| export RANK_TABLE_FILE=${current_exec_path}/rank_table_8p.json | |||||
| for((i=0; i<${DEVICE_NUM}; i++)) | for((i=0; i<${DEVICE_NUM}; i++)) | ||||
| do | do | ||||
| @@ -29,6 +29,12 @@ config = ed({ | |||||
| # neck | # neck | ||||
| 'NECK_OUT_CHANNEL': 256, | 'NECK_OUT_CHANNEL': 256, | ||||
| # lr | |||||
| "BASE_LR": 2e-3, | |||||
| "TRAIN_TOTAL_ITER": 58000, | |||||
| "WARMUP_STEP": 620, | |||||
| "WARMUP_RATIO": 1/3, | |||||
| # dataset for train | # dataset for train | ||||
| "TRAIN_ROOT_DIR": 'psenet/ic15/', | "TRAIN_ROOT_DIR": 'psenet/ic15/', | ||||
| "TRAIN_IS_TRANSFORM": True, | "TRAIN_IS_TRANSFORM": True, | ||||
| @@ -37,9 +43,8 @@ config = ed({ | |||||
| "TRAIN_MIN_SCALE": 0.4, | "TRAIN_MIN_SCALE": 0.4, | ||||
| "TRAIN_BUFFER_SIZE": 8, | "TRAIN_BUFFER_SIZE": 8, | ||||
| "TRAIN_BATCH_SIZE": 4, | "TRAIN_BATCH_SIZE": 4, | ||||
| "TRAIN_REPEAT_NUM": 608*4, | |||||
| "TRAIN_REPEAT_NUM": 1800, | |||||
| "TRAIN_DROP_REMAINDER": True, | "TRAIN_DROP_REMAINDER": True, | ||||
| "TRAIN_TOTAL_ITER": 152000, | |||||
| "TRAIN_MODEL_SAVE_PATH": './checkpoints/', | "TRAIN_MODEL_SAVE_PATH": './checkpoints/', | ||||
| # dataset for test | # dataset for test | ||||
| @@ -17,7 +17,7 @@ | |||||
| import os | import os | ||||
| import socket | import socket | ||||
| RANK_TABLE_SAVE_PATH = './rank_table_4p.json' | |||||
| RANK_TABLE_SAVE_PATH = './rank_table_8p.json' | |||||
| def main(): | def main(): | ||||
| @@ -0,0 +1,37 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """lr generator for psenet""" | |||||
| import math | |||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    """Return the linearly warmed-up LR at *current_step*.

    Interpolates from ``init_lr`` (step 0) toward ``base_lr`` over
    ``warmup_steps`` steps.
    """
    # Per-step increment of the linear ramp.
    slope = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    return float(init_lr) + slope * current_step
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
    """Cosine-annealed LR for the post-warmup phase.

    The annealing fraction is measured from the end of warmup and
    normalized by ``decay_steps`` (the caller passes the total step
    count, so the schedule does not quite reach zero at the last step).
    """
    progress = float(current_step - warmup_steps) / float(decay_steps)
    # Half-cosine from base_lr (progress=0) down toward 0 (progress=1).
    return (1 + math.cos(progress * math.pi)) / 2 * base_lr
def dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1/3):
    """Build the per-step LR schedule: linear warmup then cosine decay.

    Returns a list of ``total_steps`` learning rates, suitable for
    optimizers that accept a per-iteration LR list.
    """
    schedule = []
    for step in range(total_steps):
        if step < warmup_steps:
            # Linear ramp from base_lr * warmup_ratio up to base_lr.
            init_lr = base_lr * warmup_ratio
            inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
            schedule.append(float(init_lr) + inc * step)
        else:
            # Half-cosine annealing from base_lr toward 0, normalized by
            # the total step count (so the final value stays above 0).
            frac = float(step - warmup_steps) / float(total_steps)
            schedule.append((1 + math.cos(frac * math.pi)) / 2 * base_lr)
    return schedule
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| import math | |||||
| import argparse | import argparse | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | from mindspore import context | ||||
| @@ -29,6 +28,7 @@ from src.config import config | |||||
| from src.ETSNET.etsnet import ETSNet | from src.ETSNET.etsnet import ETSNet | ||||
| from src.ETSNET.dice_loss import DiceLoss | from src.ETSNET.dice_loss import DiceLoss | ||||
| from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack | from src.network_define import WithLossCell, TrainOneStepCell, LossCallBack | ||||
| from src.lr_schedule import dynamic_lr | |||||
| parser = argparse.ArgumentParser(description='Hyperparams') | parser = argparse.ArgumentParser(description='Hyperparams') | ||||
| parser.add_argument('--run_distribute', default=False, action='store_true', | parser.add_argument('--run_distribute', default=False, action='store_true', | ||||
| @@ -41,10 +41,6 @@ args = parser.parse_args() | |||||
| set_seed(1) | set_seed(1) | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) | ||||
| def lr_generator(start_lr, lr_scale, total_iters): | |||||
| lrs = [start_lr * (lr_scale ** math.floor(cur_iter * 1.0 / (total_iters / 3))) for cur_iter in range(total_iters)] | |||||
| return lrs | |||||
| def train(): | def train(): | ||||
| rank_id = 0 | rank_id = 0 | ||||
| if args.run_distribute: | if args.run_distribute: | ||||
| @@ -67,7 +63,7 @@ def train(): | |||||
| criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE) | criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE) | ||||
| lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER) | |||||
| lrs = dynamic_lr(config.BASE_LR, config.TRAIN_TOTAL_ITER, config.WARMUP_STEP, config.WARMUP_RATIO) | |||||
| opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4) | opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4) | ||||
| # warp model | # warp model | ||||