| @@ -15,6 +15,7 @@ | |||||
| """train_imagenet.""" | """train_imagenet.""" | ||||
| import os | import os | ||||
| import argparse | import argparse | ||||
| import numpy as np | |||||
| from dataset import create_dataset | from dataset import create_dataset | ||||
| from lr_generator import get_lr | from lr_generator import get_lr | ||||
| from config import config | from config import config | ||||
| @@ -45,6 +46,7 @@ if __name__ == '__main__': | |||||
| target = args_opt.device_target | target = args_opt.device_target | ||||
| ckpt_save_dir = config.save_checkpoint_path | ckpt_save_dir = config.save_checkpoint_path | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) | context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) | ||||
| np.random.seed(1) | |||||
| if not args_opt.do_eval and args_opt.run_distribute: | if not args_opt.do_eval and args_opt.run_distribute: | ||||
| if target == "Ascend": | if target == "Ascend": | ||||
| device_id = int(os.getenv('DEVICE_ID')) | device_id = int(os.getenv('DEVICE_ID')) | ||||
| @@ -15,6 +15,7 @@ | |||||
| """train_imagenet.""" | """train_imagenet.""" | ||||
| import os | import os | ||||
| import argparse | import argparse | ||||
| import numpy as np | |||||
| from dataset import create_dataset | from dataset import create_dataset | ||||
| from lr_generator import get_lr | from lr_generator import get_lr | ||||
| from config import config | from config import config | ||||
| @@ -48,6 +49,7 @@ if __name__ == '__main__': | |||||
| target = args_opt.device_target | target = args_opt.device_target | ||||
| ckpt_save_dir = config.save_checkpoint_path | ckpt_save_dir = config.save_checkpoint_path | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) | context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) | ||||
| np.random.seed(1) | |||||
| if not args_opt.do_eval and args_opt.run_distribute: | if not args_opt.do_eval and args_opt.run_distribute: | ||||
| if target == "Ascend": | if target == "Ascend": | ||||
| device_id = int(os.getenv('DEVICE_ID')) | device_id = int(os.getenv('DEVICE_ID')) | ||||