From: @zhanghuiyao
Reviewed-by: @oacjiewen, @c_34
Signed-off-by: @c_34
pull/15822/MERGE
@@ -84,7 +84,6 @@ The entire code structure is as following:
 │ │ ├── head.py // head unit
 │ │ ├── resnet.py // resnet architecture
 │ ├── callback_factory.py // callback logging
-│ ├── config.py // parameter configuration
 │ ├── custom_dataset.py // custom dataset and sampler
 │ ├── custom_net.py // custom cell define
 │ ├── dataset_factory.py // creating dataset
@@ -94,6 +93,15 @@ The entire code structure is as following:
 │ ├── lrsche_factory.py // learning rate schedule
 │ ├── me_init.py // network parameter init method
 │ ├── metric_factory.py // metric fc layer
+├── utils
+│ ├── __init__.py // init file
+│ ├── config.py // parameter analysis
+│ ├── device_adapter.py // device adapter
+│ ├── local_adapter.py // local adapter
+│ ├── moxing_adapter.py // moxing adapter
+├─ base_config.yaml // parameter configuration
+├─ beta_config.yaml // parameter configuration
+├─ inference_config.yaml // parameter configuration
 ├─ train.py // training scripts
 ├─ eval.py // evaluation scripts
 └─ export.py // export air model
@@ -163,6 +171,47 @@ The entire code structure is as following:
 sh run_distribute_train_beta.sh ./rank_table_8p.json
 ```
+- ModelArts (If you want to run on ModelArts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/); then you can start training as follows.)
+- base model
+```python
+# (1) Add "config_path='/path_to_code/base_config.yaml'" on the web UI.
+# (2) Perform a or b.
+#     a. Set "enable_modelarts=True" in base_config.yaml.
+#        Set "is_distributed=1" in base_config.yaml.
+#        Set any other parameters you need in base_config.yaml.
+#     b. Add "enable_modelarts=True" on the web UI.
+#        Add "is_distributed=1" on the web UI.
+#        Add any other parameters on the web UI.
+# (3) Upload a zip of the dataset to the S3 bucket. (You can also upload the original dataset, but it may be slow.)
+# (4) Set the code directory to "/path/FaceRecognition" on the web UI.
+# (5) Set the startup file to "train.py" on the web UI.
+# (6) Set "Dataset path", "Output file path", and "Job log path" to your paths on the web UI.
+# (7) Create your job.
+```
+- beta model
+```python
+# (1) Copy or upload your trained model to the S3 bucket.
+# (2) Add "config_path='/path_to_code/beta_config.yaml'" on the web UI.
+# (3) Perform a or b.
+#     a. Set "enable_modelarts=True" in beta_config.yaml.
+#        Set "is_distributed=1" in beta_config.yaml.
+#        Set "pretrained='/cache/checkpoint_path/model.ckpt'" in beta_config.yaml.
+#        Set "checkpoint_url=/The path of checkpoint in S3/" in beta_config.yaml.
+#     b. Add "enable_modelarts=True" on the web UI.
+#        Add "is_distributed=1" on the web UI.
+#        Add "pretrained='/cache/checkpoint_path/model.ckpt'" on the web UI.
+#        Add "checkpoint_url=/The path of checkpoint in S3/" on the web UI.
+# (4) Upload a zip of the dataset to the S3 bucket. (You can also upload the original dataset, but it may be slow.)
+# (5) Set the code directory to "/path/FaceRecognition" on the web UI.
+# (6) Set the startup file to "train.py" on the web UI.
+# (7) Set "Dataset path", "Output file path", and "Job log path" to your paths on the web UI.
+# (8) Create your job.
+```
 You will get the loss value of each epoch as following in "./scripts/data_parallel_log_[DEVICE_ID]/outputs/logs/[TIME].log" or "./scripts/log_parallel_graph/face_recognition_[DEVICE_ID].log":
 ```python
@@ -188,6 +237,24 @@ sh run_eval.sh [USE_DEVICE_ID]
 You will get the result as following in "./scripts/log_inference/outputs/models/logs/[TIME].log":
 [test_dataset]: zj2jk=0.9495, jk2zj=0.9480, avg=0.9487
+If you want to run on ModelArts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start evaluation as follows:
+```python
+# run evaluation on modelarts example
+# (1) Copy or upload your trained model to the S3 bucket.
+# (2) Add "config_path='/path_to_code/inference_config.yaml'" on the web UI.
+# (3) Perform a or b.
+#     a. Set "weight='/cache/checkpoint_path/model.ckpt'" in inference_config.yaml.
+#        Set "checkpoint_url=/The path of checkpoint in S3/" in inference_config.yaml.
+#     b. Add "weight='/cache/checkpoint_path/model.ckpt'" on the web UI.
+#        Add "checkpoint_url=/The path of checkpoint in S3/" on the web UI.
+# (4) Upload a zip of the dataset to the S3 bucket. (You can also upload the original dataset, but it may be slow.)
+# (5) Set the code directory to "/path/FaceRecognition" on the web UI.
+# (6) Set the startup file to "eval.py" on the web UI.
+# (7) Set "Dataset path", "Output file path", and "Job log path" to your paths on the web UI.
+# (8) Create your job.
+```
 ### Convert model
 If you want to infer the network on Ascend 310, you should convert the model to AIR:
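A hedged sketch of this conversion step follows; the actual flags of export.py are not shown in this diff, so every parameter name below is an assumption to be checked against export.py itself:

```python
# hypothetical invocation of the export script (flag names are assumptions, not the confirmed interface):
# python export.py --config_path=inference_config.yaml --pretrained=/path/to/model.ckpt --file_format=AIR
```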
--- /dev/null
+++ b/base_config.yaml
@@ -0,0 +1,76 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "Ascend"
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+train_stage: "base"
+is_distributed: 1
+
+# dataset related
+data_dir: "/cache/data/face_recognition_dataset/train_dataset/"
+num_classes: 1
+per_batch_size: 192
+need_modelarts_dataset_unzip: True
+
+# network structure related
+backbone: "r100"
+use_se: 1
+emb_size: 512
+act_type: "relu"
+fp16: 1
+pre_bn: 1
+inference: 0
+use_drop: 1
+nc_16: 1
+
+# loss related
+margin_a: 1.0
+margin_b: 0.2
+margin_m: 0.3
+margin_s: 64
+
+# optimizer related
+lr: 0.4
+lr_scale: 1
+lr_epochs: "8,14,18"
+weight_decay: 0.0002
+momentum: 0.9
+max_epoch: 20
+pretrained: ""
+warmup_epochs: 2
+
+# distributed parameter
+local_rank: 0
+world_size: 1
+model_parallel: 0
+
+# logging related
+log_interval: 100
+ckpt_path: "outputs"
+max_ckpts: -1
+dynamic_init_loss_scale: 65536
+ckpt_steps: 1000
+
+---
+# Help description for each configuration
+enable_modelarts: "Whether training on modelarts, default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: "Target device type"
+enable_profiling: "Whether enable profiling while training, default: False"
+train_stage: "Train stage, base or beta"
+is_distributed: "If multi device"
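Each scalar key in these yaml documents is also exposed as a command-line flag by utils/config.py (added later in this PR), so any value can be overridden per run without editing the file. A minimal sketch of the pattern (the override values below are illustrative, not recommended settings):

```python
# pick a config file, then override individual keys from the command line:
# python train.py --config_path=base_config.yaml --per_batch_size=96 --max_epoch=10 --is_distributed=0
```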
--- /dev/null
+++ b/beta_config.yaml
@@ -0,0 +1,76 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "Ascend"
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+train_stage: "beta"
+is_distributed: 1
+
+# dataset related
+data_dir: "/cache/data/face_recognition_dataset/train_dataset/"
+num_classes: 1
+per_batch_size: 192
+need_modelarts_dataset_unzip: True
+
+# network structure related
+backbone: "r100"
+use_se: 0
+emb_size: 256
+act_type: "relu"
+fp16: 1
+pre_bn: 0
+inference: 0
+use_drop: 1
+nc_16: 1
+
+# loss related
+margin_a: 1.0
+margin_b: 0.2
+margin_m: 0.3
+margin_s: 64
+
+# optimizer related
+lr: 0.04
+lr_scale: 1
+lr_epochs: "8,14,18"
+weight_decay: 0.0002
+momentum: 0.9
+max_epoch: 20
+pretrained: "your_pretrained_model"
+warmup_epochs: 2
+
+# distributed parameter
+local_rank: 0
+world_size: 1
+model_parallel: 0
+
+# logging related
+log_interval: 100
+ckpt_path: "outputs"
+max_ckpts: -1
+dynamic_init_loss_scale: 65536
+ckpt_steps: 1000
+
+---
+# Help description for each configuration
+enable_modelarts: "Whether training on modelarts, default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: "Target device type"
+enable_profiling: "Whether enable profiling while training, default: False"
+train_stage: "Train stage, base or beta"
+is_distributed: "If multi device"
--- a/eval.py
+++ b/eval.py
@@ -26,12 +26,14 @@ import mindspore.dataset as de
 from mindspore import Tensor, context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from src.config import config_inference
 from src.backbone.resnet import get_backbone
 from src.my_logging import get_logger

-devid = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
+from utils.config import config
+from utils.moxing_adapter import moxing_wrapper
+from utils.device_adapter import get_device_id, get_device_num, get_rank_id
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
 class TxtDataset():
@@ -198,7 +200,61 @@ def l2normalize(features):
     l2norm[np.logical_and(l2norm >= 0, l2norm < epsilon)] = epsilon
     return features/l2norm
-def main(args):
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    def unzip(zip_file, save_dir):
+        import zipfile
+        s_time = time.time()
+        if not os.path.exists(os.path.join(save_dir, "face_recognition_dataset")):
+            zip_isexist = zipfile.is_zipfile(zip_file)
+            if zip_isexist:
+                fz = zipfile.ZipFile(zip_file, 'r')
+                data_num = len(fz.namelist())
+                print("Extract Start...")
+                print("unzip file num: {}".format(data_num))
+                i = 0
+                step = max(data_num // 100, 1)  # guard against archives with fewer than 100 entries
+                for file in fz.namelist():
+                    if i % step == 0:
+                        print("unzip percent: {}%".format(i // step), flush=True)
+                    i += 1
+                    fz.extract(file, save_dir)
+                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
+                                                     int(int(time.time() - s_time) % 60)))
+                print("Extract Done.")
+            else:
+                print("This is not a zip file.")
+        else:
+            print("Zip has been extracted.")
+
+    if config.need_modelarts_dataset_unzip:
+        zip_file_1 = os.path.join(config.data_path, "face_recognition_dataset.zip")
+        save_dir_1 = os.path.join(config.data_path)
+
+        sync_lock = "/tmp/unzip_sync.lock"
+        # Each server contains at most 8 devices.
+        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+            print("Zip file path: ", zip_file_1)
+            print("Unzip file save dir: ", save_dir_1)
+            unzip(zip_file_1, save_dir_1)
+            print("===Finish extract data synchronization===")
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+
+        while True:
+            if os.path.exists(sync_lock):
+                break
+            time.sleep(1)
+
+        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
+
+    config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.ckpt_path)
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_eval(args):
+    '''run eval function.'''
     if not os.path.exists(args.test_dir):
         args.logger.info('ERROR, test_dir is not exists, please set test_dir in config.py.')
         return 0
@@ -317,17 +373,17 @@ def main(args):
     return 0

 if __name__ == '__main__':
-    arg = config_inference
-    arg.test_img_predix = [arg.test_dir, arg.test_dir]
-    arg.test_img_list = [os.path.join(arg.test_dir, 'lists/jk_list.txt'),
-                         os.path.join(arg.test_dir, 'lists/zj_list.txt')]
-    arg.dis_img_predix = [arg.test_dir,]
-    arg.dis_img_list = [os.path.join(arg.test_dir, 'lists/dis_list.txt'),]
+    config.test_img_predix = [os.path.join(config.test_dir, 'test_dataset/'),
+                              os.path.join(config.test_dir, 'test_dataset/')]
+    config.test_img_list = [os.path.join(config.test_dir, 'lists/jk_list.txt'),
+                            os.path.join(config.test_dir, 'lists/zj_list.txt')]
+    config.dis_img_predix = [os.path.join(config.test_dir, 'dis_dataset/'),]
+    config.dis_img_list = [os.path.join(config.test_dir, 'lists/dis_list.txt'),]

-    log_path = os.path.join(arg.ckpt_path, 'logs')
-    arg.logger = get_logger(log_path, arg.local_rank)
+    log_path = os.path.join(config.ckpt_path, 'logs')
+    config.logger = get_logger(log_path, config.local_rank)

-    arg.logger.info('Config: %s', pformat(arg))
-    main(arg)
+    config.logger.info('Config: %s', pformat(config))
+    run_eval(config)
--- /dev/null
+++ b/inference_config.yaml
@@ -0,0 +1,60 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "Ascend"
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+# distributed parameter
+is_distributed: 0
+local_rank: 0
+world_size: 1
+
+# test weight
+weight: 'your_test_model'
+test_dir: '/cache/data/face_recognition_dataset/'
+need_modelarts_dataset_unzip: True
+
+# model define
+backbone: "r100"
+use_se: 0
+emb_size: 256
+act_type: "relu"
+fp16: 1
+pre_bn: 0
+inference: 1
+use_drop: 0
+
+# test and dis batch size
+test_batch_size: 128
+dis_batch_size: 512
+
+# log
+log_interval: 100
+ckpt_path: "outputs/models"
+
+# test and dis image list
+test_img_predix: ""
+test_img_list: ""
+dis_img_predix: ""
+dis_img_list: ""
+
+---
+# Help description for each configuration
+enable_modelarts: "Whether training on modelarts, default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: "Target device type"
+enable_profiling: "Whether enable profiling while training, default: False"
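From the defaults above and the path joins in eval.py's `__main__` block (later in this diff), the expected layout under test_dir is roughly the following; this tree is reconstructed from the code, not verified against the dataset package:

```python
# /cache/data/face_recognition_dataset/
# ├── test_dataset/        # probe/gallery images (zj and jk)
# ├── dis_dataset/         # distractor images
# └── lists/
#     ├── jk_list.txt
#     ├── zj_list.txt
#     └── dis_list.txt
```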
--- a/scripts/run_distribute_train_base.sh
+++ b/scripts/run_distribute_train_base.sh
@@ -59,6 +59,7 @@ do
     echo "start training for rank $RANK_ID, device $DEVICE_ID"
     env > ${EXECUTE_PATH}/log_parallel_graph/face_recognition_$i.log
     python ${EXECUTE_PATH}/../train.py \
+        --config_path=${EXECUTE_PATH}/../base_config.yaml \
         --train_stage=base \
         --is_distributed=1 &> ${EXECUTE_PATH}/log_parallel_graph/face_recognition_$i.log &
 done
--- a/scripts/run_distribute_train_beta.sh
+++ b/scripts/run_distribute_train_beta.sh
@@ -59,6 +59,7 @@ do
     echo "start training for rank $RANK_ID, device $DEVICE_ID"
     env > ${EXECUTE_PATH}/log_parallel_graph/face_recognition_$i.log
     python ${EXECUTE_PATH}/../train.py \
+        --config_path=${EXECUTE_PATH}/../beta_config.yaml \
         --train_stage=beta \
         --is_distributed=1 &> ${EXECUTE_PATH}/log_parallel_graph/face_recognition_$i.log &
 done
--- a/scripts/run_eval.sh
+++ b/scripts/run_eval.sh
@@ -41,6 +41,6 @@ mkdir ${EXECUTE_PATH}/log_inference
 cd ${EXECUTE_PATH}/log_inference || exit
 env > ${EXECUTE_PATH}/log_inference/face_recognition.log
-python ${EXECUTE_PATH}/../eval.py &> ${EXECUTE_PATH}/log_inference/face_recognition.log &
+python ${EXECUTE_PATH}/../eval.py --config_path=${EXECUTE_PATH}/../inference_config.yaml &> ${EXECUTE_PATH}/log_inference/face_recognition.log &
 echo "[INFO] Start inference..."
--- a/scripts/run_standalone_train_base.sh
+++ b/scripts/run_standalone_train_base.sh
@@ -46,6 +46,7 @@ cd ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID || exit
 echo "start training for rank $RANK_ID, device $USE_DEVICE_ID"
 env > ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log
 python ${EXECUTE_PATH}/../train.py \
+    --config_path=${EXECUTE_PATH}/../base_config.yaml \
     --train_stage=base \
     --is_distributed=0 &> ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log &
--- a/scripts/run_standalone_train_beta.sh
+++ b/scripts/run_standalone_train_beta.sh
@@ -46,6 +46,7 @@ cd ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID || exit
 echo "start training for rank $RANK_ID, device $USE_DEVICE_ID"
 env > ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log
 python ${EXECUTE_PATH}/../train.py \
+    --config_path=${EXECUTE_PATH}/../beta_config.yaml \
     --train_stage=beta \
     --is_distributed=0 &> ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log &
--- a/src/config.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""network config setting, will be used in train.py and eval.py."""
-from easydict import EasyDict as edict
-
-config_base = edict({
-    # dataset related
-    'data_dir': "your_dataset_path",
-    'num_classes': 1,
-    'per_batch_size': 192,
-
-    # network structure related
-    'backbone': 'r100',
-    'use_se': 1,
-    'emb_size': 512,
-    'act_type': 'relu',
-    'fp16': 1,
-    'pre_bn': 1,
-    'inference': 0,
-    'use_drop': 1,
-    'nc_16': 1,
-
-    # loss related
-    'margin_a': 1.0,
-    'margin_b': 0.2,
-    'margin_m': 0.3,
-    'margin_s': 64,
-
-    # optimizer related
-    'lr': 0.4,
-    'lr_scale': 1,
-    'lr_epochs': '8,14,18',
-    'weight_decay': 0.0002,
-    'momentum': 0.9,
-    'max_epoch': 20,
-    'pretrained': '',
-    'warmup_epochs': 2,
-
-    # distributed parameter
-    'is_distributed': 1,
-    'local_rank': 0,
-    'world_size': 1,
-    'model_parallel': 0,
-
-    # logging related
-    'log_interval': 100,
-    'ckpt_path': 'outputs',
-    'max_ckpts': -1,
-    'dynamic_init_loss_scale': 65536,
-    'ckpt_steps': 1000
-})
-
-config_beta = edict({
-    # dataset related
-    'data_dir': "your_dataset_path",
-    'num_classes': 1,
-    'per_batch_size': 192,
-
-    # network structure related
-    'backbone': 'r100',
-    'use_se': 0,
-    'emb_size': 256,
-    'act_type': 'relu',
-    'fp16': 1,
-    'pre_bn': 0,
-    'inference': 0,
-    'use_drop': 1,
-    'nc_16': 1,
-
-    # loss related
-    'margin_a': 1.0,
-    'margin_b': 0.2,
-    'margin_m': 0.3,
-    'margin_s': 64,
-
-    # optimizer related
-    'lr': 0.04,
-    'lr_scale': 1,
-    'lr_epochs': '8,14,18',
-    'weight_decay': 0.0002,
-    'momentum': 0.9,
-    'max_epoch': 20,
-    'pretrained': 'your_pretrained_model',
-    'warmup_epochs': 2,
-
-    # distributed parameter
-    'is_distributed': 1,
-    'local_rank': 0,
-    'world_size': 1,
-    'model_parallel': 0,
-
-    # logging related
-    'log_interval': 100,
-    'ckpt_path': 'outputs',
-    'max_ckpts': -1,
-    'dynamic_init_loss_scale': 65536,
-    'ckpt_steps': 1000
-})
-
-config_inference = edict({
-    # distributed parameter
-    'is_distributed': 0,
-    'local_rank': 0,
-    'world_size': 1,
-
-    # test weight
-    'weight': 'your_test_model',
-    'test_dir': 'your_dataset_path',
-
-    # model define
-    'backbone': 'r100',
-    'use_se': 0,
-    'emb_size': 256,
-    'act_type': 'relu',
-    'fp16': 1,
-    'pre_bn': 0,
-    'inference': 1,
-    'use_drop': 0,
-
-    # test and dis batch size
-    'test_batch_size': 128,
-    'dis_batch_size': 512,
-
-    # log
-    'log_interval': 100,
-    'ckpt_path': 'outputs/models',
-
-    # test and dis image list
-    'test_img_predix': '',
-    'test_img_list': '',
-    'dis_img_predix': '',
-    'dis_img_list': ''
-})
--- a/train.py
+++ b/train.py
@@ -14,20 +14,19 @@
 # ============================================================================
 """Face Recognition train."""
 import os
-import argparse
+import time

 import mindspore
 from mindspore.nn import Cell
 from mindspore import context
 from mindspore.context import ParallelMode
-from mindspore.communication.management import get_group_size, init, get_rank
+from mindspore.communication.management import init
 from mindspore.nn.optim import Momentum
 from mindspore.train.model import Model
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from src.config import config_base, config_beta
 from src.my_logging import get_logger
 from src.init_network import init_net
 from src.dataset_factory import get_de_dataset
@@ -37,10 +36,13 @@ from src.loss_factory import get_loss
 from src.lrsche_factory import warmup_step_list, list_to_gen
 from src.callback_factory import ProgressMonitor
+from utils.moxing_adapter import moxing_wrapper
+from utils.config import config
+from utils.device_adapter import get_device_id, get_device_num, get_rank_id

 mindspore.common.seed.set_seed(1)
-devid = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
-                    device_id=devid, reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
+                    device_id=get_device_id(), reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
 class DistributedHelper(Cell):
     '''DistributedHelper'''
@@ -84,103 +86,13 @@ class BuildTrainNetwork(Cell):
         return loss
-def parse_args():
-    parser = argparse.ArgumentParser('MindSpore Face Recognition')
-    parser.add_argument('--train_stage', type=str, default='base', help='train stage, base or beta')
-    parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
-    args_opt_1, _ = parser.parse_known_args()
-    return args_opt_1
-
-if __name__ == "__main__":
-    args_opt = parse_args()
-    support_train_stage = ['base', 'beta']
-    if args_opt.train_stage.lower() not in support_train_stage:
-        args.logger.info('support train stage is:{}, while yours is:{}'.
-                         format(support_train_stage, args_opt.train_stage))
-        raise ValueError('train stage not support.')
-    args = config_base if args_opt.train_stage.lower() == 'base' else config_beta
-    args.is_distributed = args_opt.is_distributed
-    if args_opt.is_distributed:
-        init()
-        args.local_rank = get_rank()
-        args.world_size = get_group_size()
-        parallel_mode = ParallelMode.HYBRID_PARALLEL
-    else:
-        parallel_mode = ParallelMode.STAND_ALONE
-    context.set_auto_parallel_context(parallel_mode=parallel_mode,
-                                      device_num=args.world_size, gradients_mean=True)
-    if not os.path.exists(args.data_dir):
-        args.logger.info('ERROR, data_dir is not exists, please set data_dir in config.py')
-        raise ValueError('ERROR, data_dir is not exists, please set data_dir in config.py')
-    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
-    log_path = os.path.join(args.ckpt_path, 'logs')
-    args.logger = get_logger(log_path, args.local_rank)
-    if args.local_rank % 8 == 0:
-        if not os.path.exists(args.ckpt_path):
-            os.makedirs(args.ckpt_path)
-    args.logger.info('args.world_size:{}'.format(args.world_size))
-    args.logger.info('args.local_rank:{}'.format(args.local_rank))
-    args.logger.info('args.lr:{}'.format(args.lr))
-    momentum = args.momentum
-    weight_decay = args.weight_decay
-    de_dataset, steps_per_epoch, num_classes = get_de_dataset(args)
-    args.logger.info('de_dataset:{}'.format(de_dataset.get_dataset_size()))
-    args.steps_per_epoch = steps_per_epoch
-    args.num_classes = num_classes
-    args.logger.info('loaded, nums: {}'.format(args.num_classes))
-    if args.nc_16 == 1:
-        if args.model_parallel == 0:
-            if args.num_classes % 16 == 0:
-                args.logger.info('data parallel aleardy 16, nums: {}'.format(args.num_classes))
-            else:
-                args.num_classes = (args.num_classes // 16 + 1) * 16
-        else:
-            if args.num_classes % (args.world_size * 16) == 0:
-                args.logger.info('model parallel aleardy 16, nums: {}'.format(args.num_classes))
-            else:
-                args.num_classes = (args.num_classes // (args.world_size * 16) + 1) * args.world_size * 16
-    args.logger.info('for D, loaded, class nums: {}'.format(args.num_classes))
-    args.logger.info('steps_per_epoch:{}'.format(args.steps_per_epoch))
-    args.logger.info('img_total_num:{}'.format(args.steps_per_epoch * args.per_batch_size))
-    args.logger.info('get_backbone----in----')
-    _backbone = get_backbone(args)
-    args.logger.info('get_backbone----out----')
-    args.logger.info('get_metric_fc----in----')
-    margin_fc_1 = get_metric_fc(args)
-    args.logger.info('get_metric_fc----out----')
-    args.logger.info('DistributedHelper----in----')
-    network_1 = DistributedHelper(_backbone, margin_fc_1)
-    args.logger.info('DistributedHelper----out----')
-    args.logger.info('network fp16----in----')
-    if args.fp16 == 1:
-        network_1.add_flags_recursive(fp16=True)
-    args.logger.info('network fp16----out----')
-    criterion_1 = get_loss(args)
-    if args.fp16 == 1 and args.model_parallel == 0:
-        criterion_1.add_flags_recursive(fp32=True)
-    if os.path.isfile(args.pretrained):
-        param_dict = load_checkpoint(args.pretrained)
+def load_pretrain(cfg, net):
+    '''load pretrain function.'''
+    if os.path.isfile(cfg.pretrained):
+        param_dict = load_checkpoint(cfg.pretrained)
         param_dict_new = {}
-        if args_opt.train_stage.lower() == 'base':
+        if cfg.train_stage.lower() == 'base':
             for key, value in param_dict.items():
                 if key.startswith('moments.'):
                     continue
@@ -201,35 +113,169 @@ if __name__ == "__main__":
                     continue
                 else:
                     param_dict_new[key[8:]] = value
-        load_param_into_net(network_1, param_dict_new)
-        args.logger.info('load model {} success'.format(args.pretrained))
+        load_param_into_net(net, param_dict_new)
+        cfg.logger.info('load model {} success'.format(cfg.pretrained))
     else:
-        init_net(args, network_1)
+        if cfg.train_stage.lower() == 'beta':
+            raise ValueError("train stage beta requires a pretrained model, failed to load from: {}".format(cfg.pretrained))
+        init_net(cfg, net)
+        cfg.logger.info('init model success')
+    return net
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    def unzip(zip_file, save_dir):
+        import zipfile
+        s_time = time.time()
+        if not os.path.exists(os.path.join(save_dir, "face_recognition_dataset")):
+            zip_isexist = zipfile.is_zipfile(zip_file)
+            if zip_isexist:
+                fz = zipfile.ZipFile(zip_file, 'r')
+                data_num = len(fz.namelist())
+                print("Extract Start...")
+                print("unzip file num: {}".format(data_num))
+                i = 0
+                step = max(data_num // 100, 1)  # guard against archives with fewer than 100 entries
+                for file in fz.namelist():
+                    if i % step == 0:
+                        print("unzip percent: {}%".format(i // step), flush=True)
+                    i += 1
+                    fz.extract(file, save_dir)
+                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
+                                                     int(int(time.time() - s_time) % 60)))
+                print("Extract Done.")
+            else:
+                print("This is not a zip file.")
+        else:
+            print("Zip has been extracted.")
+
+    if config.need_modelarts_dataset_unzip:
+        zip_file_1 = os.path.join(config.data_path, "face_recognition_dataset.zip")
+        save_dir_1 = os.path.join(config.data_path)
+
+        sync_lock = "/tmp/unzip_sync.lock"
+        # Each server contains at most 8 devices.
+        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+            print("Zip file path: ", zip_file_1)
+            print("Unzip file save dir: ", save_dir_1)
+            unzip(zip_file_1, save_dir_1)
+            print("===Finish extract data synchronization===")
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+
+        while True:
+            if os.path.exists(sync_lock):
+                break
+            time.sleep(1)
+
+        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
+
+    config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.ckpt_path)

-    train_net = BuildTrainNetwork(network_1, criterion_1, args)
-    args.logger.info('args:{}'.format(args))
-    # call warmup_step should behind the args steps_per_epoch
-    args.lrs = warmup_step_list(args, gamma=0.1)
-    lrs_gen = list_to_gen(args.lrs)
-    opt = Momentum(params=train_net.trainable_params(), learning_rate=lrs_gen, momentum=momentum,
-                   weight_decay=weight_decay)
-    scale_manager = DynamicLossScaleManager(init_loss_scale=args.dynamic_init_loss_scale, scale_factor=2,
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_train():
+    '''run train function.'''
+    config.local_rank = get_rank_id()
+    config.world_size = get_device_num()
+    log_path = os.path.join(config.ckpt_path, 'logs')
+    config.logger = get_logger(log_path, config.local_rank)
+
+    support_train_stage = ['base', 'beta']
+    if config.train_stage.lower() not in support_train_stage:
+        config.logger.info('your train stage is not supported.')
+        raise ValueError('train stage not support.')
+
+    if not os.path.exists(config.data_dir):
+        config.logger.info('ERROR, data_dir does not exist, please set data_dir in the config yaml')
+        raise ValueError('ERROR, data_dir does not exist, please set data_dir in the config yaml')
+
+    parallel_mode = ParallelMode.HYBRID_PARALLEL if config.is_distributed else ParallelMode.STAND_ALONE
+    context.set_auto_parallel_context(parallel_mode=parallel_mode,
+                                      device_num=config.world_size, gradients_mean=True)
+    if config.is_distributed:
+        init()
+
+    if config.local_rank % 8 == 0:
+        if not os.path.exists(config.ckpt_path):
+            os.makedirs(config.ckpt_path)
+
+    de_dataset, steps_per_epoch, num_classes = get_de_dataset(config)
+    config.logger.info('de_dataset: %d', de_dataset.get_dataset_size())
+
+    config.steps_per_epoch = steps_per_epoch
+    config.num_classes = num_classes
+    config.lr_epochs = list(map(int, config.lr_epochs.split(',')))
+    config.logger.info('config.num_classes: %d', config.num_classes)
+    config.logger.info('config.world_size: %d', config.world_size)
+    config.logger.info('config.local_rank: %d', config.local_rank)
+    config.logger.info('config.lr: %f', config.lr)
+
+    if config.nc_16 == 1:
+        if config.model_parallel == 0:
+            if config.num_classes % 16 == 0:
+                config.logger.info('data parallel already 16, nums: %d', config.num_classes)
+            else:
+                config.num_classes = (config.num_classes // 16 + 1) * 16
+        else:
+            if config.num_classes % (config.world_size * 16) == 0:
+                config.logger.info('model parallel already 16, nums: %d', config.num_classes)
+            else:
+                config.num_classes = (config.num_classes // (config.world_size * 16) + 1) * config.world_size * 16
+
+    config.logger.info('for D, loaded, class nums: %d', config.num_classes)
+    config.logger.info('steps_per_epoch: %d', config.steps_per_epoch)
+    config.logger.info('img_total_num: %d', config.steps_per_epoch * config.per_batch_size)
+
+    config.logger.info('get_backbone----in----')
+    _backbone = get_backbone(config)
+    config.logger.info('get_backbone----out----')
+    config.logger.info('get_metric_fc----in----')
+    margin_fc_1 = get_metric_fc(config)
+    config.logger.info('get_metric_fc----out----')
+
+    config.logger.info('DistributedHelper----in----')
+    network_1 = DistributedHelper(_backbone, margin_fc_1)
+    config.logger.info('DistributedHelper----out----')
+    config.logger.info('network fp16----in----')
+    if config.fp16 == 1:
+        network_1.add_flags_recursive(fp16=True)
+    config.logger.info('network fp16----out----')
+
+    criterion_1 = get_loss(config)
+    if config.fp16 == 1 and config.model_parallel == 0:
+        criterion_1.add_flags_recursive(fp32=True)
+
+    network_1 = load_pretrain(config, network_1)
+    train_net = BuildTrainNetwork(network_1, criterion_1, config)
+
+    # warmup_step_list must be called after config.steps_per_epoch is set
+    config.lrs = warmup_step_list(config, gamma=0.1)
+    lrs_gen = list_to_gen(config.lrs)
+    opt = Momentum(params=train_net.trainable_params(), learning_rate=lrs_gen, momentum=config.momentum,
+                   weight_decay=config.weight_decay)
+    scale_manager = DynamicLossScaleManager(init_loss_scale=config.dynamic_init_loss_scale, scale_factor=2,
                                             scale_window=2000)
     model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=scale_manager)
-    save_checkpoint_steps = args.ckpt_steps
-    args.logger.info('save_checkpoint_steps:{}'.format(save_checkpoint_steps))
-    if args.max_ckpts == -1:
-        keep_checkpoint_max = int(args.steps_per_epoch * args.max_epoch / save_checkpoint_steps) + 5  # for more than 5
+    save_checkpoint_steps = config.ckpt_steps
+    config.logger.info('save_checkpoint_steps: %d', save_checkpoint_steps)
+    if config.max_ckpts == -1:
+        keep_checkpoint_max = int(config.steps_per_epoch * config.max_epoch / save_checkpoint_steps) + 5
     else:
-        keep_checkpoint_max = args.max_ckpts
-    args.logger.info('keep_checkpoint_max:{}'.format(keep_checkpoint_max))
+        keep_checkpoint_max = config.max_ckpts
+    config.logger.info('keep_checkpoint_max: %d', keep_checkpoint_max)
     ckpt_config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps, keep_checkpoint_max=keep_checkpoint_max)
-    max_epoch_train = args.max_epoch
-    args.logger.info('max_epoch_train:{}'.format(max_epoch_train))
-    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.ckpt_path, prefix='{}'.format(args.local_rank))
-    args.epoch_cnt = 0
-    progress_cb = ProgressMonitor(args)
-    new_epoch_train = max_epoch_train * steps_per_epoch // args.log_interval
-    model.train(new_epoch_train, de_dataset, callbacks=[progress_cb, ckpt_cb], sink_size=args.log_interval)
+    config.logger.info('max_epoch_train: %d', config.max_epoch)
+    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=config.ckpt_path, prefix='{}'.format(config.local_rank))
+    config.epoch_cnt = 0
+    progress_cb = ProgressMonitor(config)
+    new_epoch_train = config.max_epoch * steps_per_epoch // config.log_interval
+    model.train(new_epoch_train, de_dataset, callbacks=[progress_cb, ckpt_cb], sink_size=config.log_interval)
+
+
+if __name__ == "__main__":
+    run_train()
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,127 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Parse arguments"""
+
+import os
+import ast
+import argparse
+from pprint import pprint, pformat
+import yaml
+
+
+class Config:
+    """
+    Configuration namespace. Convert dictionary to members.
+    """
+    def __init__(self, cfg_dict):
+        for k, v in cfg_dict.items():
+            if isinstance(v, (list, tuple)):
+                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
+            else:
+                setattr(self, k, Config(v) if isinstance(v, dict) else v)
+
+    def __str__(self):
+        return pformat(self.__dict__)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
+    """
+    Parse command line arguments to the configuration according to the default yaml.
+
+    Args:
+        parser: Parent parser.
+        cfg: Base configuration.
+        helper: Helper description.
+        cfg_path: Path to the default yaml config.
+    """
+    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
+                                     parents=[parser])
+    helper = {} if helper is None else helper
+    choices = {} if choices is None else choices
+    for item in cfg:
+        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
+            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
+            choice = choices[item] if item in choices else None
+            if isinstance(cfg[item], bool):
+                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
+                                    help=help_description)
+            else:
+                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
+                                    help=help_description)
+    args = parser.parse_args()
+    return args
+
+
+def parse_yaml(yaml_path):
+    """
+    Parse the yaml config file.
+
+    Args:
+        yaml_path: Path to the yaml config.
+    """
+    with open(yaml_path, 'r') as fin:
+        try:
+            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
+            cfgs = [x for x in cfgs]
+            if len(cfgs) == 1:
+                cfg_helper = {}
+                cfg = cfgs[0]
+                cfg_choices = {}
+            elif len(cfgs) == 2:
+                cfg, cfg_helper = cfgs
+                cfg_choices = {}
+            elif len(cfgs) == 3:
+                cfg, cfg_helper, cfg_choices = cfgs
+            else:
+                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
+            print(cfg_helper)
+        except:
+            raise ValueError("Failed to parse yaml")
+    return cfg, cfg_helper, cfg_choices
+
+
+def merge(args, cfg):
+    """
+    Merge the base config from yaml file and command line arguments.
+
+    Args:
+        args: Command line arguments.
+        cfg: Base configuration.
+    """
+    args_var = vars(args)
+    for item in args_var:
+        cfg[item] = args_var[item]
+    return cfg
+
+
+def get_config():
+    """
+    Get Config according to the yaml file and cli arguments.
+    """
+    parser = argparse.ArgumentParser(description="default name", add_help=False)
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
+                        help="Config file path")
+    path_args, _ = parser.parse_known_args()
+    default, helper, choices = parse_yaml(path_args.config_path)
+    pprint(default)
+    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
+    final_config = merge(args, default)
+    return Config(final_config)
+
+config = get_config()
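A minimal sketch of how the resulting object behaves, using the Config class defined above (the dictionary is illustrative):

```python
# given the Config class above:
cfg = Config({"lr": 0.4, "optimizer": {"momentum": 0.9}})
print(cfg.lr)                  # 0.4  -- top-level keys become attributes
print(cfg.optimizer.momentum)  # 0.9  -- nested dicts become nested Config instances
print(cfg)                     # __str__ pretty-prints the namespace via pformat
```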
--- /dev/null
+++ b/utils/device_adapter.py
@@ -0,0 +1,27 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Device adapter for ModelArts"""
+
+from utils.config import config
+
+if config.enable_modelarts:
+    from utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+else:
+    from utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+
+__all__ = [
+    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
+]
--- /dev/null
+++ b/utils/local_adapter.py
@@ -0,0 +1,36 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Local adapter"""
+
+import os
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    return "Local Job"
--- /dev/null
+++ b/utils/moxing_adapter.py
@@ -0,0 +1,116 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Moxing adapter for ModelArts"""
+
+import os
+import functools
+from mindspore import context
+from utils.config import config
+
+_global_sync_count = 0
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID', '')  # default to '' so the fallback below triggers when the variable is unset
+    job_id = job_id if job_id != "" else "default"
+    return job_id
+def sync_data(from_path, to_path):
+    """
+    Download data from remote obs to local directory if the first url is remote url and the second one is local path
+    Upload data from local directory to remote obs in contrast.
+    """
+    import moxing as mox
+    import time
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server contains at most 8 devices.
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+        except IOError:
+            pass
+        print("===save flag===")
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
+
+
+def moxing_wrapper(pre_process=None, post_process=None):
+    """
+    Moxing wrapper to download dataset and upload outputs.
+    """
+    def wrapper(run_func):
+        @functools.wraps(run_func)
+        def wrapped_func(*args, **kwargs):
+            # Download data from data_url
+            if config.enable_modelarts:
+                if config.data_url:
+                    sync_data(config.data_url, config.data_path)
+                    print("Dataset downloaded: ", os.listdir(config.data_path))
+                if config.checkpoint_url:
+                    sync_data(config.checkpoint_url, config.load_path)
+                    print("Preload downloaded: ", os.listdir(config.load_path))
+                if config.train_url:
+                    sync_data(config.train_url, config.output_path)
+                    print("Workspace downloaded: ", os.listdir(config.output_path))
+
+                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
+                config.device_num = get_device_num()
+                config.device_id = get_device_id()
+                if not os.path.exists(config.output_path):
+                    os.makedirs(config.output_path)
+
+                if pre_process:
+                    pre_process()
+
+            # Run the main function
+            run_func(*args, **kwargs)
+
+            # Upload data to train_url
+            if config.enable_modelarts:
+                if post_process:
+                    post_process()
+
+                if config.train_url:
+                    print("Start to copy output directory")
+                    sync_data(config.output_path, config.train_url)
+        return wrapped_func
+    return wrapper
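The wrapper is applied to the entry function, exactly as train.py and eval.py do above. A minimal sketch of the pattern (the function names here are illustrative):

```python
from utils.moxing_adapter import moxing_wrapper

def prepare():
    '''e.g. unzip the dataset and adjust config paths before training.'''

@moxing_wrapper(pre_process=prepare)
def run():
    '''training or evaluation body goes here.'''

if __name__ == "__main__":
    # on ModelArts: download inputs, run, then upload outputs; locally: just run
    run()
```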