|
|
|
@@ -14,9 +14,9 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
|
|
|
|
# Require two or three positional arguments: RANK_TABLE_FILE, DATA_PATH and an
# optional dataset name. Any other count prints the usage line and aborts.
# The count is compared numerically (-ne) rather than as a string.
if [ "$#" -ne 2 ] && [ "$#" -ne 3 ]
then
    echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [cifar10|imagenet2012]"
    exit 1
fi
|
|
|
|
|
|
|
@@ -32,6 +32,19 @@ then |
|
|
|
exit 1 |
|
|
|
fi |
|
|
|
|
|
|
|
|
|
|
|
# Choose the dataset to train on.
# Defaults to cifar10; when a third positional argument is supplied it must be
# exactly "cifar10" or "imagenet2012", otherwise the script aborts with status 1.
# The test uses -eq and a quoted "$3" so it also works when the script is run
# with a plain POSIX sh (where `==` inside [ ] is not supported) and when the
# argument is empty or contains whitespace.
dataset_type='cifar10'
if [ "$#" -eq 3 ]
then
    case "$3" in
        cifar10|imagenet2012)
            dataset_type=$3
            ;;
        *)
            echo "error: the selected dataset is neither cifar10 nor imagenet2012"
            exit 1
            ;;
    esac
fi
|
|
|
|
|
|
|
|
|
|
|
# Export the settings read by the processes this script launches.
# DEVICE_NUM and RANK_SIZE are fixed at 8; RANK_TABLE_FILE is taken from the
# first positional argument. "$1" is quoted so a path containing spaces is
# exported intact even by shells that word-split arguments to `export`.
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE="$1"
|
|
|
@@ -45,8 +58,8 @@ do |
|
|
|
cp *.py ./train_parallel$i |
|
|
|
cp -r src ./train_parallel$i |
|
|
|
cd ./train_parallel$i || exit |
|
|
|
echo "start training for rank $RANK_ID, device $DEVICE_ID" |
|
|
|
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type" |
|
|
|
env > env.log |
|
|
|
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log & |
|
|
|
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 --dataset=$dataset_type &> log & |
|
|
|
cd .. |
|
|
|
done |
|
|
|
done |