Browse Source

add relative path support in resnet50_cifar10 example

tags/v0.3.0-alpha
gengdongjie 5 years ago
parent
commit
5375b66cc9
2 changed files with 17 additions and 4 deletions
  1. +15
    -3
      example/resnet50_cifar10/run_infer.sh
  2. +2
    -1
      example/resnet50_cifar10/train.py

+ 15
- 3
example/resnet50_cifar10/run_infer.sh View File

@@ -20,13 +20,25 @@ then
exit 1
fi

if [ ! -d $1 ]
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)


if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
exit 1
fi

if [ ! -f $2 ]
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$2 is not a file"
exit 1
@@ -48,5 +60,5 @@ cp *.sh ./infer
cd ./infer || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py --do_eval=True --dataset_path=$1 --checkpoint_path=$2 &> log &
python eval.py --do_eval=True --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

+ 2
- 1
example/resnet50_cifar10/train.py View File

@@ -77,7 +77,8 @@ if __name__ == '__main__':
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay, config.loss_scale)

model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2",
keep_batchnorm_fp32=False)

time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()


Loading…
Cancel
Save