diff --git a/model_zoo/official/nlp/transformer/scripts/run_standalone_train.sh b/model_zoo/official/nlp/transformer/scripts/run_standalone_train.sh index be9c22ee0a..c8d3b6eca4 100644 --- a/model_zoo/official/nlp/transformer/scripts/run_standalone_train.sh +++ b/model_zoo/official/nlp/transformer/scripts/run_standalone_train.sh @@ -15,7 +15,7 @@ # ============================================================================ echo "==============================================================================================================" -echo "Please run the scipt as: " +echo "Please run the script as: " echo "sh run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH" echo "for example: sh run_standalone_train.sh Ascend 0 52 /path/ende-l128-mindrecord00" echo "It is better to use absolute path." @@ -31,17 +31,36 @@ DEVICE_ID=$2 EPOCH_SIZE=$3 DATA_PATH=$4 -python train.py \ - --distribute="false" \ - --epoch_size=$EPOCH_SIZE \ - --device_target=$DEVICE_TARGET \ - --device_id=$DEVICE_ID \ - --enable_save_ckpt="true" \ - --enable_lossscale="true" \ - --do_shuffle="true" \ - --checkpoint_path="" \ - --save_checkpoint_steps=2500 \ - --save_checkpoint_num=30 \ - --data_path=$DATA_PATH \ - --bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 & +if [ $DEVICE_TARGET == 'Ascend' ];then + python train.py \ + --distribute="false" \ + --epoch_size=$EPOCH_SIZE \ + --device_target=$DEVICE_TARGET \ + --device_id=$DEVICE_ID \ + --enable_save_ckpt="true" \ + --enable_lossscale="true" \ + --do_shuffle="true" \ + --checkpoint_path="" \ + --save_checkpoint_steps=2500 \ + --save_checkpoint_num=30 \ + --data_path=$DATA_PATH \ + --bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 & +elif [ $DEVICE_TARGET == 'GPU' ];then + export CUDA_VISIBLE_DEVICES="$2" + + python train.py \ + --distribute="false" \ + --epoch_size=$EPOCH_SIZE \ + --device_target=$DEVICE_TARGET \ + --enable_save_ckpt="true" \ + --enable_lossscale="true" \ + --do_shuffle="true" \ + --checkpoint_path="" \ + --save_checkpoint_steps=2500 \ + --save_checkpoint_num=30 \ + --data_path=$DATA_PATH \ + --bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 & +else + echo "Not supported device target." +fi cd .. diff --git a/model_zoo/official/nlp/transformer/train.py b/model_zoo/official/nlp/transformer/train.py index af8557d57e..21049f8644 100644 --- a/model_zoo/official/nlp/transformer/train.py +++ b/model_zoo/official/nlp/transformer/train.py @@ -122,7 +122,10 @@ def run_transformer_train(): """ parser = argparse_init() args, _ = parser.parse_known_args() - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) + if args.device_target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) + else: + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(reserve_class_name_in_scope=False, enable_auto_mixed_precision=False) if args.distribute == "true":