|
|
|
@@ -147,6 +147,8 @@ def parse_args(cloud_args=None): |
|
|
|
args.lr_epochs = list(map(int, args.lr_epochs.split(','))) |
|
|
|
args.image_size = list(map(int, args.image_size.split(','))) |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, |
|
|
|
device_target=args.platform, save_graphs=False) |
|
|
|
# init distributed |
|
|
|
if args.is_distributed: |
|
|
|
init() |
|
|
|
@@ -190,8 +192,6 @@ def merge_args(args, cloud_args): |
|
|
|
def train(cloud_args=None): |
|
|
|
"""training process""" |
|
|
|
args = parse_args(cloud_args) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, |
|
|
|
device_target=args.platform, save_graphs=False) |
|
|
|
if os.getenv('DEVICE_ID', "not_set").isdigit(): |
|
|
|
context.set_context(device_id=int(os.getenv('DEVICE_ID'))) |
|
|
|
|
|
|
|
|