| @@ -24,16 +24,16 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.deepfm import ModelBuilder, AUCMetric | |||
| from src.config import DataConfig, ModelConfig, TrainConfig | |||
| from src.dataset import create_dataset | |||
| from src.dataset import create_dataset, DataType | |||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |||
| parser = argparse.ArgumentParser(description='CTR Prediction') | |||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | |||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend, GPU, or CPU') | |||
| args_opt, _ = parser.parse_known_args() | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) | |||
| def add_write(file_path, print_str): | |||
| @@ -47,7 +47,8 @@ if __name__ == '__main__': | |||
| train_config = TrainConfig() | |||
| ds_eval = create_dataset(args_opt.dataset_path, train_mode=False, | |||
| epochs=1, batch_size=train_config.batch_size) | |||
| epochs=1, batch_size=train_config.batch_size, | |||
| data_type=DataType(data_config.data_format)) | |||
| model_builder = ModelBuilder(ModelConfig, TrainConfig) | |||
| train_net, eval_net = model_builder.get_train_eval_net() | |||
| train_net.set_train() | |||
| @@ -0,0 +1,38 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # Launches data-parallel GPU training through mpirun from a fresh ./log directory. | |||
| # Usage: sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH | |||
| echo "Please run the script as: " | |||
| echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH" | |||
| echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path" | |||
| echo "After running the script, the network runs in the background, The log will be generated in log/output.log" | |||
| # Both positional arguments are required; without them mpirun would receive an empty -n value. | |||
| if [ "$#" -lt 2 ]; then | |||
|     echo "error: DEVICE_NUM and DATASET_PATH are required" >&2 | |||
|     exit 1 | |||
| fi | |||
| export RANK_SIZE="$1" | |||
| # NOTE: the dataset path is consumed after 'cd ./log' below, so a relative path | |||
| # is resolved against ./log rather than the directory the script was started from. | |||
| DATA_URL="$2" | |||
| # Recreate ./log and copy the sources into it so the run uses a snapshot of the code. | |||
| rm -rf log | |||
| mkdir ./log | |||
| cp *.py ./log | |||
| cp -r src ./log | |||
| cd ./log || exit | |||
| # Record the environment next to the training logs. | |||
| env > env.log | |||
| # Quote the expansions so values containing spaces are passed as single arguments. | |||
| mpirun --allow-run-as-root -n "$RANK_SIZE" \ | |||
|     python -u train.py \ | |||
|     --dataset_path="$DATA_URL" \ | |||
|     --ckpt_path="checkpoint" \ | |||
|     --eval_file_name='auc.log' \ | |||
|     --loss_file_name='loss.log' \ | |||
|     --device_target='GPU' \ | |||
|     --do_eval=True > output.log 2>&1 & | |||
| @@ -14,13 +14,14 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "Please run the script as: " | |||
| echo "sh scripts/run_eval.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH" | |||
| echo "for example: sh scripts/run_eval.sh 0 /dataset_path /checkpoint_path" | |||
| echo "sh scripts/run_eval.sh DEVICE_ID DEVICE_TARGET DATASET_PATH CHECKPOINT_PATH" | |||
| echo "for example: sh scripts/run_eval.sh 0 GPU /dataset_path /checkpoint_path" | |||
| echo "After running the script, the network runs in the background, The log will be generated in ms_log/eval_output.log" | |||
| export DEVICE_ID=$1 | |||
| DATA_URL=$2 | |||
| CHECKPOINT_PATH=$3 | |||
| DEVICE_TARGET=$2 | |||
| DATA_URL=$3 | |||
| CHECKPOINT_PATH=$4 | |||
| mkdir -p ms_log | |||
| CUR_DIR=`pwd` | |||
| @@ -29,4 +30,5 @@ export GLOG_logtostderr=0 | |||
| python -u eval.py \ | |||
| --dataset_path=$DATA_URL \ | |||
| --checkpoint_path=$CHECKPOINT_PATH > ms_log/eval_output.log 2>&1 & | |||
| --checkpoint_path=$CHECKPOINT_PATH \ | |||
| --device_target=$DEVICE_TARGET > ms_log/eval_output.log 2>&1 & | |||
| @@ -14,12 +14,13 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "Please run the script as: " | |||
| echo "sh scripts/run_standalone_train.sh DEVICE_ID DATASET_PATH" | |||
| echo "for example: sh scripts/run_standalone_train.sh 0 /dataset_path" | |||
| echo "sh scripts/run_standalone_train.sh DEVICE_ID DEVICE_TARGET DATASET_PATH" | |||
| echo "for example: sh scripts/run_standalone_train.sh 0 GPU /dataset_path" | |||
| echo "After running the script, the network runs in the background, The log will be generated in ms_log/output.log" | |||
| export DEVICE_ID=$1 | |||
| DATA_URL=$2 | |||
| DEVICE_TARGET=$2 | |||
| DATA_URL=$3 | |||
| mkdir -p ms_log | |||
| CUR_DIR=`pwd` | |||
| @@ -31,4 +32,5 @@ python -u train.py \ | |||
| --ckpt_path="checkpoint" \ | |||
| --eval_file_name='auc.log' \ | |||
| --loss_file_name='loss.log' \ | |||
| --device_target=$DEVICE_TARGET \ | |||
| --do_eval=True > ms_log/output.log 2>&1 & | |||
| @@ -16,11 +16,14 @@ | |||
| import os | |||
| import sys | |||
| import argparse | |||
| import random | |||
| import numpy as np | |||
| from mindspore import context, ParallelMode | |||
| from mindspore.communication.management import init | |||
| from mindspore.communication.management import init, get_rank, get_group_size | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | |||
| import mindspore.dataset.engine as de | |||
| from src.deepfm import ModelBuilder, AUCMetric | |||
| from src.config import DataConfig, ModelConfig, TrainConfig | |||
| @@ -34,24 +37,41 @@ parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path | |||
| parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path') | |||
| parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path') | |||
| parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend, GPU, or CPU') | |||
| args_opt, _ = parser.parse_known_args() | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) | |||
| rank_size = int(os.environ.get("RANK_SIZE", 1)) | |||
| random.seed(1) | |||
| np.random.seed(1) | |||
| de.config.set_seed(1) | |||
| if __name__ == '__main__': | |||
| data_config = DataConfig() | |||
| model_config = ModelConfig() | |||
| train_config = TrainConfig() | |||
| rank_size = int(os.environ.get("RANK_SIZE", 1)) | |||
| if rank_size > 1: | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) | |||
| init() | |||
| rank_id = int(os.environ.get('RANK_ID')) | |||
| if args_opt.device_target == "Ascend": | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) | |||
| init() | |||
| rank_id = int(os.environ.get('RANK_ID')) | |||
| elif args_opt.device_target == "GPU": | |||
| init("nccl") | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| mirror_mean=True) | |||
| rank_id = get_rank() | |||
| else: | |||
| print("Unsupported device_target ", args_opt.device_target) | |||
| exit() | |||
| else: | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) | |||
| rank_size = None | |||
| rank_id = None | |||
| @@ -73,6 +93,8 @@ if __name__ == '__main__': | |||
| callback_list = [time_callback, loss_callback] | |||
| if train_config.save_checkpoint: | |||
| if rank_size: | |||
| train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank()) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps, | |||
| keep_checkpoint_max=train_config.keep_checkpoint_max) | |||
| ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix, | |||