Browse Source

fix xception and textrcnn

tags/v1.1.0
Yanjun Peng 5 years ago
parent
commit
f708e4b108
2 changed files with 15 additions and 15 deletions
  1. +14
    -14
      model_zoo/official/cv/xception/train.py
  2. +1
    -1
      model_zoo/official/nlp/textrcnn/scripts/run_eval.sh

+ 14
- 14
model_zoo/official/cv/xception/train.py View File

@@ -98,24 +98,24 @@ if __name__ == '__main__':
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
args_opt = parser.parse_args()

-# init distributed
-if args_opt.is_distributed:
-if os.getenv('DEVICE_ID', "not_set").isdigit():
-context.set_context(device_id=int(os.getenv('DEVICE_ID')))
-rank = get_rank()
-group_size = get_group_size()
-parallel_mode = ParallelMode.DATA_PARALLEL
-context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
-init()
-else:
-rank = 0
-group_size = 1
-context.set_context(device_id=0)
-
if args_opt.device_target == "Ascend":
#train on Ascend
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)

+# init distributed
+if args_opt.is_distributed:
+if os.getenv('DEVICE_ID', "not_set").isdigit():
+context.set_context(device_id=int(os.getenv('DEVICE_ID')))
+init()
+rank = get_rank()
+group_size = get_group_size()
+parallel_mode = ParallelMode.DATA_PARALLEL
+context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
+else:
+rank = 0
+group_size = 1
+context.set_context(device_id=0)
+
# define network
net = xception(class_num=config.class_num)
net.to_float(mstype.float16)


+ 1
- 1
model_zoo/official/nlp/textrcnn/scripts/run_eval.sh View File

@@ -17,4 +17,4 @@ ulimit -u unlimited

BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
-python ${BASEPATH}/../eval.py > --ckpt_path $1 ./eval.log 2>&1 &
+python ${BASEPATH}/../eval.py --ckpt_path $1 > ./eval.log 2>&1 &

Loading…
Cancel
Save