From: @caojian05 Reviewed-by: @oacjiewen,@wuxuejian Signed-off-by: @wuxuejianpull/15268/MERGE
| @@ -73,6 +73,7 @@ Dataset used: | |||||
| ├── README.md // descriptions about FCN | ├── README.md // descriptions about FCN | ||||
| ├── scripts | ├── scripts | ||||
| ├── run_train.sh | ├── run_train.sh | ||||
| ├── run_standalone_train.sh | |||||
| ├── run_eval.sh | ├── run_eval.sh | ||||
| ├── build_data.sh | ├── build_data.sh | ||||
| ├── src | ├── src | ||||
| @@ -114,13 +115,13 @@ Dataset used: | |||||
| # model | # model | ||||
| 'model': 'FCN8s', | 'model': 'FCN8s', | ||||
| 'ckpt_vgg16': '/data/workspace/mindspore_dataset/FCN/FCN/model/0-150_5004.ckpt', | |||||
| 'ckpt_pre_trained': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/FCN8s-500_82.ckpt', | |||||
| 'ckpt_vgg16': '', | |||||
| 'ckpt_pre_trained': '', | |||||
| # train | # train | ||||
| 'save_steps': 330, | 'save_steps': 330, | ||||
| 'keep_checkpoint_max': 500, | |||||
| 'train_dir': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/', | |||||
| 'keep_checkpoint_max': 5, | |||||
| 'ckpt_dir': './ckpt', | |||||
| ``` | ``` | ||||
| 如需获取更多信息,请查看`config.py`. | 如需获取更多信息,请查看`config.py`. | ||||
| @@ -281,7 +282,7 @@ Dataset used: | |||||
| if args.rank == 0: | if args.rank == 0: | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_steps, | config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_steps, | ||||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | keep_checkpoint_max=cfg.keep_checkpoint_max) | ||||
| ckpoint_cb = ModelCheckpoint(prefix=cfg.model, directory=cfg.train_dir, config=config_ck) | |||||
| ckpoint_cb = ModelCheckpoint(prefix=cfg.model, directory=cfg.ckpt_dir, config=config_ck) | |||||
| cbs.append(ckpoint_cb) | cbs.append(ckpoint_cb) | ||||
| model.train(cfg.train_epochs, dataset, callbacks=cbs) | model.train(cfg.train_epochs, dataset, callbacks=cbs) | ||||
| @@ -0,0 +1,38 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| if [ $# != 1 ] | |||||
| then | |||||
| echo "Usage: sh run_standalone_train.sh [device_num]" | |||||
| exit 1 | |||||
| fi | |||||
| export DEVICE_ID=$1 | |||||
| train_path=train_standalone${DEVICE_ID} | |||||
| if [ -d ${train_path} ]; then | |||||
| rm -rf ${train_path} | |||||
| fi | |||||
| mkdir -p ${train_path} | |||||
| cp -r ./src ${train_path} | |||||
| cp ./train.py ${train_path} | |||||
| echo "start training for device $DEVICE_ID" | |||||
| cd ${train_path}|| exit | |||||
| python train.py --device_id=${DEVICE_ID} > log 2>&1 & | |||||
| cd .. | |||||
| @@ -38,11 +38,11 @@ FCN8s_VOC2012_cfg = edict({ | |||||
| # model | # model | ||||
| 'model': 'FCN8s', | 'model': 'FCN8s', | ||||
| 'ckpt_vgg16': '/data/workspace/mindspore_dataset/FCN/FCN/model/0-150_5004.ckpt', | |||||
| 'ckpt_pre_trained': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/FCN8s-500_82.ckpt', | |||||
| 'ckpt_vgg16': '', | |||||
| 'ckpt_pre_trained': '', | |||||
| # train | # train | ||||
| 'save_steps': 330, | 'save_steps': 330, | ||||
| 'keep_checkpoint_max': 500, | |||||
| 'train_dir': '/data/workspace/mindspore_dataset/FCN/FCN/model_new/', | |||||
| 'keep_checkpoint_max': 5, | |||||
| 'ckpt_dir': './ckpt', | |||||
| }) | }) | ||||
| @@ -128,7 +128,7 @@ def train(): | |||||
| if args.rank == 0: | if args.rank == 0: | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_steps, | config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_steps, | ||||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | keep_checkpoint_max=cfg.keep_checkpoint_max) | ||||
| ckpoint_cb = ModelCheckpoint(prefix=cfg.model, directory=cfg.train_dir, config=config_ck) | |||||
| ckpoint_cb = ModelCheckpoint(prefix=cfg.model, directory=cfg.ckpt_dir, config=config_ck) | |||||
| cbs.append(ckpoint_cb) | cbs.append(ckpoint_cb) | ||||
| model.train(cfg.train_epochs, dataset, callbacks=cbs) | model.train(cfg.train_epochs, dataset, callbacks=cbs) | ||||