diff --git a/example/resnet50_cifar10/train.py b/example/resnet50_cifar10/train.py index 275f7188a7..323695ae29 100755 --- a/example/resnet50_cifar10/train.py +++ b/example/resnet50_cifar10/train.py @@ -15,6 +15,7 @@ """train_imagenet.""" import os import argparse +import numpy as np from dataset import create_dataset from lr_generator import get_lr from config import config @@ -45,6 +46,7 @@ if __name__ == '__main__': target = args_opt.device_target ckpt_save_dir = config.save_checkpoint_path context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + np.random.seed(1) if not args_opt.do_eval and args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) diff --git a/example/resnet50_imagenet2012/train.py b/example/resnet50_imagenet2012/train.py index a76de78f6d..abb55731dc 100755 --- a/example/resnet50_imagenet2012/train.py +++ b/example/resnet50_imagenet2012/train.py @@ -15,6 +15,7 @@ """train_imagenet.""" import os import argparse +import numpy as np from dataset import create_dataset from lr_generator import get_lr from config import config @@ -48,6 +49,7 @@ if __name__ == '__main__': target = args_opt.device_target ckpt_save_dir = config.save_checkpoint_path context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + np.random.seed(1) if not args_opt.do_eval and args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID'))