Browse Source

!13914 modified train.py in yolov3_darknet53 network

From: @shuzigood
Reviewed-by: @ouwenchang,@wuxuejian
Signed-off-by: @wuxuejian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
e16630c844
1 changed files with 10 additions and 4 deletions
  1. +10
    -4
      model_zoo/official/cv/yolov3_darknet53/train.py

+ 10
- 4
model_zoo/official/cv/yolov3_darknet53/train.py View File

@@ -135,6 +135,14 @@ def network_init(args):
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=False, device_id=devid)

profiler = None
if args.need_profiler:
from mindspore.profiler import Profiler
profiling_dir = os.path.join("profiling",
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
profiler = Profiler(output_path=profiling_dir, is_detail=True, is_show_op_path=True)

# init distributed
if args.is_distributed:
if args.device_target == "Ascend":
@@ -155,6 +163,7 @@ def network_init(args):
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
args.logger.save_args(args)
return profiler


def parallel_init(args):
@@ -169,10 +178,7 @@ def parallel_init(args):
def train():
"""Train function."""
args = parse_args()
network_init(args)
if args.need_profiler:
from mindspore.profiler import Profiler
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)
profiler = network_init(args)

loss_meter = AverageMeter('loss')
parallel_init(args)


Loading…
Cancel
Save