|
|
|
@@ -15,7 +15,7 @@ |
|
|
|
# ============================================================================ |
|
|
|
|
|
|
|
echo "==============================================================================================================" |
|
|
|
echo "Please run the scipt as: " |
|
|
|
echo "Please run the script as: " |
|
|
|
echo "sh run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH" |
|
|
|
echo "for example: sh run_standalone_train.sh Ascend 0 52 /path/ende-l128-mindrecord00" |
|
|
|
echo "It is better to use absolute path." |
|
|
|
@@ -31,17 +31,36 @@ DEVICE_ID=$2 |
|
|
|
EPOCH_SIZE=$3 |
|
|
|
DATA_PATH=$4 |
|
|
|
|
|
|
|
python train.py \ |
|
|
|
--distribute="false" \ |
|
|
|
--epoch_size=$EPOCH_SIZE \ |
|
|
|
--device_target=$DEVICE_TARGET \ |
|
|
|
--device_id=$DEVICE_ID \ |
|
|
|
--enable_save_ckpt="true" \ |
|
|
|
--enable_lossscale="true" \ |
|
|
|
--do_shuffle="true" \ |
|
|
|
--checkpoint_path="" \ |
|
|
|
--save_checkpoint_steps=2500 \ |
|
|
|
--save_checkpoint_num=30 \ |
|
|
|
--data_path=$DATA_PATH \ |
|
|
|
--bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 & |
|
|
|
if [ $DEVICE_TARGET == 'Ascend' ];then |
|
|
|
python train.py \ |
|
|
|
--distribute="false" \ |
|
|
|
--epoch_size=$EPOCH_SIZE \ |
|
|
|
--device_target=$DEVICE_TARGET \ |
|
|
|
--device_id=$DEVICE_ID \ |
|
|
|
--enable_save_ckpt="true" \ |
|
|
|
--enable_lossscale="true" \ |
|
|
|
--do_shuffle="true" \ |
|
|
|
--checkpoint_path="" \ |
|
|
|
--save_checkpoint_steps=2500 \ |
|
|
|
--save_checkpoint_num=30 \ |
|
|
|
--data_path=$DATA_PATH \ |
|
|
|
--bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 & |
|
|
|
elif [ $DEVICE_TARGET == 'GPU' ];then |
|
|
|
export CUDA_VISIBLE_DEVICES="$2" |
|
|
|
|
|
|
|
python train.py \ |
|
|
|
--distribute="false" \ |
|
|
|
--epoch_size=$EPOCH_SIZE \ |
|
|
|
--device_target=$DEVICE_TARGET \ |
|
|
|
--enable_save_ckpt="true" \ |
|
|
|
--enable_lossscale="true" \ |
|
|
|
--do_shuffle="true" \ |
|
|
|
--checkpoint_path="" \ |
|
|
|
--save_checkpoint_steps=2500 \ |
|
|
|
--save_checkpoint_num=30 \ |
|
|
|
--data_path=$DATA_PATH \ |
|
|
|
--bucket_boundaries=[16,32,48,64,128] > log.txt 2>&1 & |
|
|
|
else |
|
|
|
echo "Not supported device target." |
|
|
|
fi |
|
|
|
cd .. |