|
|
|
@@ -135,6 +135,14 @@ def network_init(args): |
|
|
|
devid = int(os.getenv('DEVICE_ID', '0')) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, |
|
|
|
device_target=args.device_target, save_graphs=False, device_id=devid) |
|
|
|
|
|
|
|
profiler = None |
|
|
|
if args.need_profiler: |
|
|
|
from mindspore.profiler import Profiler |
|
|
|
profiling_dir = os.path.join("profiling", |
|
|
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) |
|
|
|
profiler = Profiler(output_path=profiling_dir, is_detail=True, is_show_op_path=True) |
|
|
|
|
|
|
|
# init distributed |
|
|
|
if args.is_distributed: |
|
|
|
if args.device_target == "Ascend": |
|
|
|
@@ -155,6 +163,7 @@ def network_init(args): |
|
|
|
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) |
|
|
|
args.logger = get_logger(args.outputs_dir, args.rank) |
|
|
|
args.logger.save_args(args) |
|
|
|
return profiler |
|
|
|
|
|
|
|
|
|
|
|
def parallel_init(args): |
|
|
|
@@ -169,10 +178,7 @@ def parallel_init(args): |
|
|
|
def train(): |
|
|
|
"""Train function.""" |
|
|
|
args = parse_args() |
|
|
|
network_init(args) |
|
|
|
if args.need_profiler: |
|
|
|
from mindspore.profiler import Profiler |
|
|
|
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) |
|
|
|
profiler = network_init(args) |
|
|
|
|
|
|
|
loss_meter = AverageMeter('loss') |
|
|
|
parallel_init(args) |
|
|
|
|