|
|
@@ -131,9 +131,7 @@ def conver_training_shape(args): |
|
|
return training_shape |
|
|
return training_shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(): |
|
|
|
|
|
"""Train function.""" |
|
|
|
|
|
args = parse_args() |
|
|
|
|
|
|
|
|
def network_init(args): |
|
|
devid = int(os.getenv('DEVICE_ID', '0')) |
|
|
devid = int(os.getenv('DEVICE_ID', '0')) |
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, |
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, |
|
|
device_target=args.device_target, save_graphs=True, device_id=devid) |
|
|
device_target=args.device_target, save_graphs=True, device_id=devid) |
|
|
@@ -145,26 +143,21 @@ def train(): |
|
|
init("nccl") |
|
|
init("nccl") |
|
|
args.rank = get_rank() |
|
|
args.rank = get_rank() |
|
|
args.group_size = get_group_size() |
|
|
args.group_size = get_group_size() |
|
|
# select for master rank save ckpt or all rank save, compatible for model parallel
|
|
|
|
|
|
|
|
# select for master rank save ckpt or all rank save, compatible for model parallel |
|
|
args.rank_save_ckpt_flag = 0 |
|
|
args.rank_save_ckpt_flag = 0 |
|
|
if args.is_save_on_master: |
|
|
if args.is_save_on_master: |
|
|
if args.rank == 0: |
|
|
if args.rank == 0: |
|
|
args.rank_save_ckpt_flag = 1 |
|
|
args.rank_save_ckpt_flag = 1 |
|
|
else: |
|
|
else: |
|
|
args.rank_save_ckpt_flag = 1 |
|
|
args.rank_save_ckpt_flag = 1 |
|
|
|
|
|
|
|
|
# logger |
|
|
# logger |
|
|
args.outputs_dir = os.path.join(args.ckpt_path, |
|
|
args.outputs_dir = os.path.join(args.ckpt_path, |
|
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) |
|
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) |
|
|
args.logger = get_logger(args.outputs_dir, args.rank) |
|
|
args.logger = get_logger(args.outputs_dir, args.rank) |
|
|
args.logger.save_args(args) |
|
|
args.logger.save_args(args) |
|
|
|
|
|
|
|
|
if args.need_profiler: |
|
|
|
|
|
from mindspore.profiler.profiling import Profiler |
|
|
|
|
|
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) |
|
|
|
|
|
|
|
|
|
|
|
loss_meter = AverageMeter('loss') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parallel_init(args): |
|
|
context.reset_auto_parallel_context() |
|
|
context.reset_auto_parallel_context() |
|
|
parallel_mode = ParallelMode.STAND_ALONE |
|
|
parallel_mode = ParallelMode.STAND_ALONE |
|
|
degree = 1 |
|
|
degree = 1 |
|
|
@@ -173,6 +166,17 @@ def train(): |
|
|
degree = get_group_size() |
|
|
degree = get_group_size() |
|
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree) |
|
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree) |
|
|
|
|
|
|
|
|
|
|
|
def train(): |
|
|
|
|
|
"""Train function.""" |
|
|
|
|
|
args = parse_args() |
|
|
|
|
|
network_init(args) |
|
|
|
|
|
if args.need_profiler: |
|
|
|
|
|
from mindspore.profiler.profiling import Profiler |
|
|
|
|
|
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) |
|
|
|
|
|
|
|
|
|
|
|
loss_meter = AverageMeter('loss') |
|
|
|
|
|
parallel_init(args) |
|
|
|
|
|
|
|
|
network = YOLOV3DarkNet53(is_training=True) |
|
|
network = YOLOV3DarkNet53(is_training=True) |
|
|
# default is kaiming-normal |
|
|
# default is kaiming-normal |
|
|
default_recurisive_init(network) |
|
|
default_recurisive_init(network) |
|
|
@@ -182,7 +186,6 @@ def train(): |
|
|
args.logger.info('finish get network') |
|
|
args.logger.info('finish get network') |
|
|
|
|
|
|
|
|
config = ConfigYOLOV3DarkNet53() |
|
|
config = ConfigYOLOV3DarkNet53() |
|
|
|
|
|
|
|
|
config.label_smooth = args.label_smooth |
|
|
config.label_smooth = args.label_smooth |
|
|
config.label_smooth_factor = args.label_smooth_factor |
|
|
config.label_smooth_factor = args.label_smooth_factor |
|
|
|
|
|
|
|
|
@@ -202,7 +205,6 @@ def train(): |
|
|
args.ckpt_interval = args.steps_per_epoch |
|
|
args.ckpt_interval = args.steps_per_epoch |
|
|
|
|
|
|
|
|
lr = get_lr(args) |
|
|
lr = get_lr(args) |
|
|
|
|
|
|
|
|
opt = Momentum(params=get_param_groups(network), |
|
|
opt = Momentum(params=get_param_groups(network), |
|
|
learning_rate=Tensor(lr), |
|
|
learning_rate=Tensor(lr), |
|
|
momentum=args.momentum, |
|
|
momentum=args.momentum, |
|
|
@@ -281,7 +283,6 @@ def train(): |
|
|
if i == 10: |
|
|
if i == 10: |
|
|
profiler.analyse() |
|
|
profiler.analyse() |
|
|
break |
|
|
break |
|
|
|
|
|
|
|
|
args.logger.info('==========end training===============') |
|
|
args.logger.info('==========end training===============') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|