|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
"""train_imagenet.""" |
|
|
|
import os |
|
|
|
import argparse |
|
|
|
import numpy as np |
|
|
|
from dataset import create_dataset |
|
|
|
from lr_generator import get_lr |
|
|
|
from config import config |
|
|
|
@@ -45,6 +46,7 @@ if __name__ == '__main__': |
|
|
|
target = args_opt.device_target |
|
|
|
ckpt_save_dir = config.save_checkpoint_path |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) |
|
|
|
np.random.seed(1) |
|
|
|
if not args_opt.do_eval and args_opt.run_distribute: |
|
|
|
if target == "Ascend": |
|
|
|
device_id = int(os.getenv('DEVICE_ID')) |
|
|
|
|