Browse Source

vgg16 support imagenet dataset on Ascend

tags/v1.0.0
CaoJian 5 years ago
parent
commit
41e6ceaa72
2 changed files with 21 additions and 5 deletions
  1. +18
    -5
      model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh
  2. +3
    -0
      model_zoo/official/cv/vgg16/src/vgg.py

+ 18
- 5
model_zoo/official/cv/vgg16/scripts/run_distribute_train.sh View File

@@ -14,9 +14,9 @@
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH]"
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [cifar10|imagenet2012]"
exit 1
fi

@@ -32,6 +32,19 @@ then
exit 1
fi


dataset_type='cifar10'
if [ $# == 3 ]
then
if [ $3 != "cifar10" ] && [ $3 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
dataset_type=$3
fi


export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$1
@@ -45,8 +58,8 @@ do
cp *.py ./train_parallel$i
cp -r src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
env > env.log
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log &
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 --dataset=$dataset_type &> log &
cd ..
done
done

+ 3
- 0
model_zoo/official/cv/vgg16/src/vgg.py View File

@@ -139,5 +139,8 @@ def vgg16(num_classes=1000, args=None, phase="train"):
>>> vgg16(num_classes=1000, args=args)
"""

if args is None:
from .config import cifar_cfg
args = cifar_cfg
net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
return net

Loading…
Cancel
Save