From: @c_34 Reviewed-by: @wuxuejian,@liangchenghui Signed-off-by: @liangchenghuitags/v1.1.0
| @@ -160,15 +160,15 @@ max_text_length": 23, # max number of digits in each | |||||
| ### [Training](#contents) | ### [Training](#contents) | ||||
| - Run `run_standalone_train.sh` for non-distributed training of CRNN model, either on Ascend or on GPU. | |||||
| - Run `run_standalone_train.sh` for non-distributed training of CRNN model, only support Ascend now. | |||||
| ``` bash | ``` bash | ||||
| bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] | |||||
| bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional) | |||||
| ``` | ``` | ||||
| #### [Distributed Training](#contents) | #### [Distributed Training](#contents) | ||||
| - Run `run_distribute_train.sh` for distributed training of WarpCTC model on Ascend. | |||||
| - Run `run_distribute_train.sh` for distributed training of CRNN model on Ascend. | |||||
| ``` bash | ``` bash | ||||
| bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH] | bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH] | ||||
| @@ -188,7 +188,7 @@ Epoch time: 2743.688s, per step time: 0.097s | |||||
| - Run `run_eval.sh` for evaluation. | - Run `run_eval.sh` for evaluation. | ||||
| ``` bash | ``` bash | ||||
| bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM] | |||||
| bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional) | |||||
| ``` | ``` | ||||
| Check the `eval/log.txt` and you will get outputs as following: | Check the `eval/log.txt` and you will get outputs as following: | ||||
| @@ -232,7 +232,7 @@ result: {'CRNNAccuracy': (0.806)} | |||||
| | Dataset | SVT | IIIT5K | | | Dataset | SVT | IIIT5K | | ||||
| | batch_size | 1 | 1 | | | batch_size | 1 | 1 | | ||||
| | outputs | ACC | ACC | | | outputs | ACC | ACC | | ||||
| | Accuracy | 80.9% | 80.6% | | |||||
| | Accuracy | 80.8% | 79.7% | | |||||
| | Model for inference | 83M (.ckpt file) | 83M (.ckpt file) | | | Model for inference | 83M (.ckpt file) | 83M (.ckpt file) | | ||||
| ## [Description of Random Situation](#contents) | ## [Description of Random Situation](#contents) | ||||
| @@ -14,8 +14,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| if [ $# != 4 ]; then | |||||
| echo "Usage: sh run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]" | |||||
| if [ $# != 4 ] && [ $# != 3 ]; then | |||||
| echo "Usage: sh run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM](optional) " | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -30,7 +30,12 @@ get_real_path() { | |||||
| DATASET_NAME=$1 | DATASET_NAME=$1 | ||||
| PATH1=$(get_real_path $2) | PATH1=$(get_real_path $2) | ||||
| PATH2=$(get_real_path $3) | PATH2=$(get_real_path $3) | ||||
| PLATFORM=$4 | |||||
| if [ $# == 4 ]; then | |||||
| PLATFORM=$4 | |||||
| else | |||||
| PLATFORM="Ascend" | |||||
| fi | |||||
| if [ ! -d $PATH1 ]; then | if [ ! -d $PATH1 ]; then | ||||
| echo "error: DATASET_PATH=$PATH1 is not a directory" | echo "error: DATASET_PATH=$PATH1 is not a directory" | ||||
| @@ -14,8 +14,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| if [ $# != 3 ]; then | |||||
| echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]" | |||||
| if [ $# != 3 ] && [ $# != 2 ]; then | |||||
| echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM](optional)" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -29,7 +29,11 @@ get_real_path() { | |||||
| DATASET_NAME=$1 | DATASET_NAME=$1 | ||||
| PATH1=$(get_real_path $2) | PATH1=$(get_real_path $2) | ||||
| PLATFORM=$3 | |||||
| if [ $# == 3 ]; then | |||||
| PLATFORM=$3 | |||||
| else | |||||
| PLATFORM="Ascend" | |||||
| fi | |||||
| if [ ! -d $PATH1 ]; then | if [ ! -d $PATH1 ]; then | ||||
| echo "error: DATASET_PATH=$PATH1 is not a directory" | echo "error: DATASET_PATH=$PATH1 is not a directory" | ||||
| @@ -58,7 +62,7 @@ run_gpu() { | |||||
| if [ -d "train" ]; then | if [ -d "train" ]; then | ||||
| rm -rf ./train | rm -rf ./train | ||||
| fi | fi | ||||
| WORKDIR=./train$(DEVICE_ID) | |||||
| WORKDIR=./train${DEVICE_ID} | |||||
| mkdir $WORKDIR | mkdir $WORKDIR | ||||
| cp ../*.py $WORKDIR | cp ../*.py $WORKDIR | ||||
| cp -r ../src $WORKDIR | cp -r ../src $WORKDIR | ||||
| @@ -34,8 +34,8 @@ set_seed(1) | |||||
| parser = argparse.ArgumentParser(description="crnn training") | parser = argparse.ArgumentParser(description="crnn training") | ||||
| parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.") | parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.") | ||||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') | parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') | ||||
| parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], | |||||
| help='Running platform, choose from Ascend, GPU, and default is Ascend.') | |||||
| parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend'], | |||||
| help='Running platform, only support Ascend now. Default is Ascend.') | |||||
| parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase") | parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase") | ||||
| parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k']) | parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k']) | ||||
| parser.set_defaults(run_distribute=False) | parser.set_defaults(run_distribute=False) | ||||
| @@ -92,7 +92,7 @@ if __name__ == '__main__': | |||||
| model = Model(net) | model = Model(net) | ||||
| # define callbacks | # define callbacks | ||||
| callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] | callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] | ||||
| if config.save_checkpoint: | |||||
| if config.save_checkpoint and rank == 0: | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, | config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, | ||||
| keep_checkpoint_max=config.keep_checkpoint_max) | keep_checkpoint_max=config.keep_checkpoint_max) | ||||
| save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') | save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') | ||||