| @@ -135,6 +135,14 @@ def network_init(args): | |||||
| devid = int(os.getenv('DEVICE_ID', '0')) | devid = int(os.getenv('DEVICE_ID', '0')) | ||||
| context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, | context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, | ||||
| device_target=args.device_target, save_graphs=False, device_id=devid) | device_target=args.device_target, save_graphs=False, device_id=devid) | ||||
| profiler = None | |||||
| if args.need_profiler: | |||||
| from mindspore.profiler import Profiler | |||||
| profiling_dir = os.path.join("profiling", | |||||
| datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||||
| profiler = Profiler(output_path=profiling_dir, is_detail=True, is_show_op_path=True) | |||||
| # init distributed | # init distributed | ||||
| if args.is_distributed: | if args.is_distributed: | ||||
| if args.device_target == "Ascend": | if args.device_target == "Ascend": | ||||
| @@ -155,6 +163,7 @@ def network_init(args): | |||||
| datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | ||||
| args.logger = get_logger(args.outputs_dir, args.rank) | args.logger = get_logger(args.outputs_dir, args.rank) | ||||
| args.logger.save_args(args) | args.logger.save_args(args) | ||||
| return profiler | |||||
| def parallel_init(args): | def parallel_init(args): | ||||
| @@ -169,10 +178,7 @@ def parallel_init(args): | |||||
| def train(): | def train(): | ||||
| """Train function.""" | """Train function.""" | ||||
| args = parse_args() | args = parse_args() | ||||
| network_init(args) | |||||
| if args.need_profiler: | |||||
| from mindspore.profiler import Profiler | |||||
| profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) | |||||
| profiler = network_init(args) | |||||
| loss_meter = AverageMeter('loss') | loss_meter = AverageMeter('loss') | ||||
| parallel_init(args) | parallel_init(args) | ||||