diff --git a/model_zoo/official/recommend/ncf/README.md b/model_zoo/official/recommend/ncf/README.md
index 913c6d418c..80def6f436 100644
--- a/model_zoo/official/recommend/ncf/README.md
+++ b/model_zoo/official/recommend/ncf/README.md
@@ -100,6 +100,33 @@ sh scripts/run_train.sh rank_table.json
 sh run_eval.sh
 ```
 
+If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/); you can start training and evaluation as follows:
+
+```python
+# run distributed training on modelarts example
+# (1) First, perform a or b.
+#       a. Set "enable_modelarts=True" in the default_config.yaml file.
+#          Set the other parameters you need in the default_config.yaml file.
+#       b. Add "enable_modelarts=True" on the website UI interface.
+#          Add the other parameters on the website UI interface.
+# (2) Set the code directory to "/path/ncf" on the website UI interface.
+# (3) Set the startup file to "train.py" on the website UI interface.
+# (4) Set "Dataset path", "Output file path" and "Job log path" to your own paths on the website UI interface.
+# (5) Create your job.
+
+# run evaluation on modelarts example
+# (1) Copy or upload your trained model to an S3 bucket.
+# (2) Perform a or b.
+#       a. Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" in the default_config.yaml file.
+#          Set "checkpoint_url=/The path of checkpoint in S3/" in the default_config.yaml file.
+#       b. Add "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
+#          Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
+# (3) Set the code directory to "/path/ncf" on the website UI interface.
+# (4) Set the startup file to "eval.py" on the website UI interface.
+# (5) Set "Dataset path", "Output file path" and "Job log path" to your own paths on the website UI interface.
+# (6) Create your job.
+```
+
 # [Script Description](#contents)
 
 ## [Script and Sample Code](#contents)
@@ -108,6 +135,9 @@ sh run_eval.sh
 ├── ModelZoo_NCF_ME
 ├── README.md // descriptions about NCF
 ├── scripts
+│ ├──ascend_distributed_launcher
+│ ├──__init__.py // init file
+│ ├──get_distribute_pretrain_cmd.py // create the distributed training shell script
 │ ├──run_train.sh // shell script for train
 │ ├──run_distribute_train.sh // shell script for distribute train
 │ ├──run_eval.sh // shell script for evaluation
@@ -116,15 +146,20 @@ sh run_eval.sh
 ├── src
 │ ├──dataset.py // creating dataset
 │ ├──ncf.py // ncf architecture
-│ ├──config.py // parameter configuration
 │ ├──movielens.py // data download file
 │ ├──callbacks.py // model loss and eval callback file
 │ ├──constants.py // the constants of model
 │ ├──export.py // export checkpoint files into geir/onnx
 │ ├──metrics.py // the file for auc compute
 │ ├──stat_utils.py // the file for data process functions
+├── utils
+│ ├──config.py // parameter parsing
+│ ├──device_adapter.py // device adapter
+│ ├──local_adapter.py // local adapter
+│ ├──moxing_adapter.py // moxing adapter
+├── default_config.yaml // parameter configuration
 ├── train.py // training script
-├── eval.py // evaluation script
+├── eval.py // evaluation script
 ```
 
 ## [Script Parameters](#contents)
@@ -144,7 +179,9 @@ Parameters for both training and evaluation can be set in config.py.
 * `--num_factors`:The Embedding size of MF model.
 * `--output_path`:The location of the output file.
 * `--eval_file_name` : Eval output file.
-* `--loss_file_name` : Loss output file.
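+* `--checkpoint_file_path` : The checkpoint file to load for evaluation.
+* `--device_target` : Device where the code runs, default is "Ascend".
+* `--enable_modelarts` : Whether training on ModelArts, default is False.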
 ```
 
 ## [Training Process](#contents)
diff --git a/model_zoo/official/recommend/ncf/default_config.yaml b/model_zoo/official/recommend/ncf/default_config.yaml
new file mode 100644
index 0000000000..84dcc0ede6
--- /dev/null
+++ b/model_zoo/official/recommend/ncf/default_config.yaml
@@ -0,0 +1,54 @@
+# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "Ascend"
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+dataset: "ml-1m"
+train_epochs: 14
+batch_size: 256
+eval_batch_size: 160000
+num_neg: 4
+layers: [64, 32, 16]
+num_factors: 16
+checkpoint_path: "./checkpoint/"
+
+# Eval options
+eval_file_name: "eval.log"
+checkpoint_file_path: "./checkpoint/NCF-14_19418.ckpt"
+
+# Export options
+device_id: 0
+ckpt_file: ""
+file_name: ""
+file_format: ""
+
+---
+
+# Help description for each configuration
+enable_modelarts: "Whether training on ModelArts, default: False"
+data_url: "Url for ModelArts"
+train_url: "Url for ModelArts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: "Target device type"
+enable_profiling: "Whether to enable profiling while training, default: False"
+dataset: "Dataset to be trained and evaluated, choice: ['ml-1m', 'ml-20m']"
+train_epochs: "The number of epochs used to train."
+batch_size: "Batch size for training and evaluation"
+eval_batch_size: "The batch size used for evaluation."
+num_neg: "The number of negative instances to pair with a positive instance."
+layers: "The sizes of hidden layers for MLP"
+num_factors: "The Embedding size of MF model."
+checkpoint_path: "The directory where training checkpoints are saved."
+eval_file_name: "Eval output file."
+checkpoint_file_path: "The checkpoint file to load for evaluation."
\ No newline at end of file
diff --git a/model_zoo/official/recommend/ncf/eval.py b/model_zoo/official/recommend/ncf/eval.py
index 8f60b13995..be1b5f930f 100644
--- a/model_zoo/official/recommend/ncf/eval.py
+++ b/model_zoo/official/recommend/ncf/eval.py
@@ -15,7 +15,6 @@
 """Using for eval the model checkpoint"""
 import os
-import argparse
 
 from absl import logging
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
@@ -26,31 +25,31 @@ from src.dataset import create_dataset
 from src.metrics import NCFMetric
 from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap, PredictWithSigmoid
-from src.config import cfg
-logging.set_verbosity(logging.INFO)
-
+from utils.config import config
+from utils.moxing_adapter import moxing_wrapper
+from utils.device_adapter import get_device_id
 
-parser = argparse.ArgumentParser(description='NCF')
-parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
-parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
-parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
-parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
-parser.add_argument("--checkpoint_file_path", type=str, default="./checkpoint/NCF-14_19418.ckpt") # The location of the checkpoint file.
-args, _ = parser.parse_known_args() +logging.set_verbosity(logging.INFO) -def test_eval(): +@moxing_wrapper() +def run_eval(): """eval method""" - if not os.path.exists(args.output_path): - os.makedirs(args.output_path) + if not os.path.exists(config.output_path): + os.makedirs(config.output_path) + + context.set_context(mode=context.GRAPH_MODE, + device_target="Davinci", + save_graphs=False, + device_id=get_device_id()) - layers = cfg.layers - num_factors = cfg.num_factors + layers = config.layers + num_factors = config.num_factors topk = rconst.TOP_K num_eval_neg = rconst.NUM_EVAL_NEGATIVES - ds_eval, num_eval_users, num_eval_items = create_dataset(test_train=False, data_dir=args.data_path, - dataset=args.dataset, train_epochs=0, - eval_batch_size=cfg.eval_batch_size) + ds_eval, num_eval_users, num_eval_items = create_dataset(test_train=False, data_dir=config.data_path, + dataset=config.dataset, train_epochs=0, + eval_batch_size=config.eval_batch_size) print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) ncf_net = NCFModel(num_users=num_eval_users, @@ -60,7 +59,7 @@ def test_eval(): mf_regularization=0, mlp_reg_layers=[0.0, 0.0, 0.0, 0.0], mf_dim=16) - param_dict = load_checkpoint(args.checkpoint_file_path) + param_dict = load_checkpoint(config.checkpoint_file_path) load_param_into_net(ncf_net, param_dict) loss_net = NetWithLossClass(ncf_net) @@ -73,18 +72,12 @@ def test_eval(): ncf_metric.clear() out = model.eval(ds_eval) - eval_file_path = os.path.join(args.output_path, args.eval_file_name) + eval_file_path = os.path.join(config.output_path, config.eval_file_name) eval_file = open(eval_file_path, "a+") eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1])) eval_file.close() print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1])) - + print("=" * 100 + "Eval Finish!" + "=" * 100) if __name__ == '__main__': - devid = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, - device_target="Davinci", - save_graphs=True, - device_id=devid) - - test_eval() + run_eval() diff --git a/model_zoo/official/recommend/ncf/export.py b/model_zoo/official/recommend/ncf/export.py index 2e7b4f289d..13d7496bf1 100644 --- a/model_zoo/official/recommend/ncf/export.py +++ b/model_zoo/official/recommend/ncf/export.py @@ -13,37 +13,26 @@ # limitations under the License. 
# ============================================================================ """ncf export file""" -import argparse import numpy as np from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export import src.constants as rconst -from src.config import cfg +from utils.config import config from ncf import NCFModel, PredictWithSigmoid -parser = argparse.ArgumentParser(description='ncf export') -parser.add_argument("--device_id", type=int, default=0, help="Device id") -parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") -parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"], help="Dataset.") -parser.add_argument("--file_name", type=str, default="ncf", help="output file name.") -parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') -parser.add_argument("--device_target", type=str, default="Ascend", - choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") -args = parser.parse_args() - -context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) -if args.device_target == "Ascend": - context.set_context(device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) +if config.device_target == "Ascend": + context.set_context(device_id=config.device_id) if __name__ == "__main__": topk = rconst.TOP_K num_eval_neg = rconst.NUM_EVAL_NEGATIVES - if args.dataset == "ml-1m": + if config.dataset == "ml-1m": num_eval_users = 6040 num_eval_items = 3706 - elif args.dataset == "ml-20m": + elif config.dataset == "ml-20m": num_eval_users = 138493 num_eval_items = 26744 else: @@ -51,20 +40,20 @@ if __name__ == "__main__": ncf_net = NCFModel(num_users=num_eval_users, num_items=num_eval_items, - num_factors=cfg.num_factors, - model_layers=cfg.layers, + num_factors=config.num_factors, + model_layers=config.layers, mf_regularization=0, mlp_reg_layers=[0.0, 0.0, 0.0, 0.0], mf_dim=16) - param_dict = load_checkpoint(args.ckpt_file) + param_dict = load_checkpoint(config.ckpt_file) load_param_into_net(ncf_net, param_dict) network = PredictWithSigmoid(ncf_net, topk, num_eval_neg) - users = Tensor(np.zeros([cfg.eval_batch_size, 1]).astype(np.int32)) - items = Tensor(np.zeros([cfg.eval_batch_size, 1]).astype(np.int32)) - masks = Tensor(np.zeros([cfg.eval_batch_size, 1]).astype(np.float32)) + users = Tensor(np.zeros([config.eval_batch_size, 1]).astype(np.int32)) + items = Tensor(np.zeros([config.eval_batch_size, 1]).astype(np.int32)) + masks = Tensor(np.zeros([config.eval_batch_size, 1]).astype(np.float32)) input_data = [users, items, masks] - export(network, *input_data, file_name=args.file_name, file_format=args.file_format) + export(network, *input_data, file_name=config.file_name, file_format=config.file_format) diff --git a/model_zoo/official/recommend/ncf/scripts/ascend_distributed_launcher/__init__.py b/model_zoo/official/recommend/ncf/scripts/ascend_distributed_launcher/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/recommend/ncf/scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py b/model_zoo/official/recommend/ncf/scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py new file mode 100644 index 0000000000..a94cbd0af3 --- /dev/null +++ b/model_zoo/official/recommend/ncf/scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py @@ -0,0 +1,188 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# 
Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""distribute running script"""
+import os
+import json
+import multiprocessing
+from argparse import ArgumentParser
+
+
+def parse_args():
+    """
+    Parse command line arguments.
+
+    Args:
+
+    Returns:
+        args.
+
+    Examples:
+        >>> parse_args()
+    """
+    parser = ArgumentParser(description="Distributed training scripts generator for MindSpore")
+
+    parser.add_argument("--run_script_path", type=str, default="",
+                        help="Run script path, it is better to use absolute path")
+    parser.add_argument("--args", type=str, default="",
+                        help="Other arguments which will be passed to the main program directly")
+    parser.add_argument("--hccl_config_dir", type=str, default="",
+                        help="Hccl config path, it is better to use absolute path")
+    parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
+                        help="Path of the generated cmd file.")
+    parser.add_argument("--hccl_time_out", type=int, default=120,
+                        help="Seconds to determine the hccl time out, "
+                             "default: 120, which is the same as the hccl default config")
+    parser.add_argument("--cpu_bind", action="store_true", default=False,
+                        help="Bind cpu cores or not")
+
+    args = parser.parse_args()
+    return args
+
+
+def append_cmd(cmd, s):
+    cmd += s
+    cmd += "\n"
+    return cmd
+
+
+def append_cmd_env(cmd, key, value):
+    return append_cmd(cmd, "export " + str(key) + "=" + str(value))
+
+
+def set_envs(cmd, logic_id, rank_id):
+    """
+    Set environment variables.
+    """
+    cmd = append_cmd_env(cmd, "DEVICE_ID", str(logic_id))
+    cmd = append_cmd_env(cmd, "RANK_ID", str(rank_id))
+    return cmd
+
+
+def make_dirs(cmd, logic_id):
+    """
+    Make directories and change path.
+    """
+    cmd = append_cmd(cmd, "rm -rf LOG" + str(logic_id))
+    cmd = append_cmd(cmd, "mkdir ./LOG" + str(logic_id))
+    cmd = append_cmd(cmd, "mkdir -p ./LOG" + str(logic_id) + "/ms_log")
+    cmd = append_cmd(cmd, "env > ./LOG" + str(logic_id) + "/env.log")
+    cur_dir = os.getcwd()
+    cmd = append_cmd_env(cmd, "GLOG_log_dir", cur_dir + "/LOG" + str(logic_id) + "/ms_log")
+    cmd = append_cmd_env(cmd, "GLOG_logtostderr", "0")
+    cmd = append_cmd(cmd, "cd " + cur_dir + "/LOG" + str(logic_id))
+    return cmd
+
+
+def print_info(rank_id, device_id, logic_id, cmdopt, cur_dir):
+    """
+    Print some information about the scripts.
+    """
+    print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":")
+    print("rank_id:", rank_id)
+    print("device_id:", device_id)
+    print("logic_id", logic_id)
+    print("core_nums:", cmdopt)
+    print("log_file_dir: " + cur_dir + "/LOG" + str(logic_id) + "/pretraining_log.txt")
+
+def distribute_run():
+    """
+    Distribute pretrain scripts. The number of Ascend accelerators is allocated automatically
+    based on the device_num set in the hccl config file; you do not need to specify it.
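+
+    Example (mirroring scripts/run_distribute_train.sh; paths are placeholders, and the
+    leading token of --args is skipped by the launcher):
+        python3 get_distribute_pretrain_cmd.py --run_script_path=/path/train.py \
+            --hccl_config_dir=/path/rank_table_8p.json --hccl_time_out=600 \
+            --args=" --data_path=/path/dataset --dataset='ml-1m'" --cmd_file=distributed_cmd.sh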
+ """ + cmd = "" + print("start", __file__) + args = parse_args() + + run_script = args.run_script_path + + print("hccl_config_dir:", args.hccl_config_dir) + print("hccl_time_out:", args.hccl_time_out) + cmd = append_cmd_env(cmd, 'HCCL_CONNECT_TIMEOUT', args.hccl_time_out) + cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir) + + cores = multiprocessing.cpu_count() + print("the number of logical core:", cores) + + # get device_ips + device_ips = {} + physic_logic_ids = {} + with open('/etc/hccn.conf', 'r') as fin: + for hccn_item in fin.readlines(): + if hccn_item.strip().startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip.strip() + + if not device_ips: + raise ValueError("There is no address in /etc/hccn.conf") + + for logic_id, device_id in enumerate(sorted(device_ips.keys())): + physic_logic_ids[device_id] = logic_id + + with open(args.hccl_config_dir, "r", encoding="utf-8") as fin: + hccl_config = json.loads(fin.read()) + rank_size = 0 + for server in hccl_config["server_list"]: + rank_size += len(server["device"]) + if server["device"][0]["device_ip"] in device_ips.values(): + this_server = server + + cmd = append_cmd_env(cmd, "RANK_SIZE", str(rank_size)) + print("total rank size:", rank_size) + print("this server rank size:", len(this_server["device"])) + avg_core_per_rank = int(int(cores) / len(this_server["device"])) + core_gap = avg_core_per_rank - 1 + print("avg_core_per_rank:", avg_core_per_rank) + + count = 0 + for instance in this_server["device"]: + # device_id is the physical id, we use logic id to specific the selected device. + # While running on a server with 8 pcs, the logic ids are equal to the device ids. + device_id = instance["device_id"] + rank_id = instance["rank_id"] + logic_id = physic_logic_ids[device_id] + start = count * int(avg_core_per_rank) + count += 1 + end = start + core_gap + cmdopt = str(start) + "-" + str(end) + cur_dir = os.getcwd() + + cmd = set_envs(cmd, logic_id, rank_id) + cmd = make_dirs(cmd, logic_id) + + print_info(rank_id=rank_id, device_id=device_id, logic_id=logic_id, cmdopt=cmdopt, cur_dir=cur_dir) + + if args.cpu_bind: + run_cmd = 'taskset -c ' + cmdopt + ' ' + else: + run_cmd = "" + run_cmd += 'nohup python ' + run_script + " " + + run_cmd += " " + ' '.join([str(x) for x in args.args.split(' ')[1:]]) + run_cmd += ' >./log.txt 2>&1 &' + + cmd = append_cmd(cmd, run_cmd) + cmd = append_cmd(cmd, "cd -") + cmd = append_cmd(cmd, "echo \"run with" + + " rank_id=" + str(rank_id) + + " device_id=" + str(device_id) + + " logic_id=" + str(logic_id) + "\"") + cmd += "\n" + + with open(args.cmd_file, "w") as f: + f.write(cmd) + +if __name__ == "__main__": + distribute_run() diff --git a/model_zoo/official/recommend/ncf/scripts/run_distribute_train.sh b/model_zoo/official/recommend/ncf/scripts/run_distribute_train.sh index ec0a7f6718..ed6ad4df07 100644 --- a/model_zoo/official/recommend/ncf/scripts/run_distribute_train.sh +++ b/model_zoo/official/recommend/ncf/scripts/run_distribute_train.sh @@ -13,35 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================
-echo "Please run the script as: "
-echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH RANK_TABLE_FILE"
-echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
-current_exec_path=$(pwd)
-echo ${current_exec_path}
+if [ $# -lt 1 ]; then
+  echo "=============================================================================================================="
+  echo "Please run the script as: "
+  echo "bash run_distribute_train.sh RANK_TABLE_FILE [OTHER_ARGS]"
+  echo "OTHER_ARGS will be passed to the training script directly,"
+  echo "for example: bash run_distribute_train.sh /path/hccl.json /dataset_path"
+  echo "It is better to use absolute paths."
+  echo "=============================================================================================================="
+  exit 1
+fi
-export RANK_SIZE=$1
-data_path=$2
-export RANK_TABLE_FILE=$3
+BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
-for((i=0;i<=RANK_SIZE;i++));
-do
-    rm ${current_exec_path}/device_$i/ -rf
-    mkdir ${current_exec_path}/device_$i
-    cd ${current_exec_path}/device_$i || exit
-    export RANK_ID=$i
-    export DEVICE_ID=$i
-    python -u ${current_exec_path}/train.py \
-    --data_path $data_path \
-    --dataset 'ml-1m' \
-    --train_epochs 50 \
-    --output_path './output/' \
-    --eval_file_name 'eval.log' \
-    --loss_file_name 'loss.log' \
-    --checkpoint_path './checkpoint/' \
-    --device_target="Ascend" \
-    --device_id=$i \
-    --is_distributed=1 \
-    >log_$i.log 2>&1 &
-done
+python3 ${BASE_PATH}/ascend_distributed_launcher/get_distribute_pretrain_cmd.py \
+    --run_script_path=${BASE_PATH}/../train.py \
+    --hccl_config_dir=$1 \
+    --hccl_time_out=600 \
+    --args=" --data_path=$2 \
+--dataset='ml-1m' \
+--train_epochs=50 \
+--output_path='./output/' \
+--eval_file_name='eval.log' \
+--checkpoint_path='./checkpoint/' \
+--device_target='Ascend'" \
+    --cmd_file=distributed_cmd.sh
+bash distributed_cmd.sh
diff --git a/model_zoo/official/recommend/ncf/scripts/run_train.sh b/model_zoo/official/recommend/ncf/scripts/run_train.sh
index fabe47e406..9255209423 100644
--- a/model_zoo/official/recommend/ncf/scripts/run_train.sh
+++ b/model_zoo/official/recommend/ncf/scripts/run_train.sh
@@ -19,4 +19,4 @@ echo "for example: sh scripts/run_train.sh /dataset_path /ncf.ckpt"
 data_path=$1
 ckpt_file=$2
 
-python ./train.py --data_path $data_path --dataset 'ml-1m' --train_epochs 20 --batch_size 256 --output_path './output/' --loss_file_name 'loss.log' --checkpoint_path $ckpt_file
+python ./train.py --data_path $data_path --dataset 'ml-1m' --train_epochs 20 --batch_size 256 --output_path './output/' --checkpoint_path $ckpt_file
diff --git a/model_zoo/official/recommend/ncf/src/config.py b/model_zoo/official/recommend/ncf/src/config.py
deleted file mode 100644
index 86a001a1f7..0000000000
--- a/model_zoo/official/recommend/ncf/src/config.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -network config setting, will be used in main.py -""" -from easydict import EasyDict as edict - - -cfg = edict({ - 'dataset': 'ml-1m', # Dataset to be trained and evaluated, choice: ["ml-1m", "ml-20m"] - - 'data_dir': '../dataset', # The location of the input data. - - 'train_epochs': 14, # The number of epochs used to train. - - 'batch_size': 256, # Batch size for training and evaluation - - 'eval_batch_size': 160000, # The batch size used for evaluation. - - 'num_neg': 4, # The Number of negative instances to pair with a positive instance. - - 'layers': [64, 32, 16], # The sizes of hidden layers for MLP - - 'num_factors': 16 # The Embedding size of MF model. - - }) diff --git a/model_zoo/official/recommend/ncf/train.py b/model_zoo/official/recommend/ncf/train.py index dab7afbe33..a18eb27024 100644 --- a/model_zoo/official/recommend/ncf/train.py +++ b/model_zoo/official/recommend/ncf/train.py @@ -14,71 +14,61 @@ # ============================================================================ """Training entry file""" import os - -import argparse from absl import logging from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore import context, Model from mindspore.context import ParallelMode -from mindspore.communication.management import get_rank, get_group_size, init +from mindspore.communication.management import init from mindspore.common import set_seed from src.dataset import create_dataset from src.ncf import NCFModel, NetWithLossClass, TrainStepWrap -from config import cfg +from utils.moxing_adapter import moxing_wrapper +from utils.config import config +from utils.device_adapter import get_device_id, get_device_num, get_rank_id, get_job_id set_seed(1) logging.set_verbosity(logging.INFO) -parser = argparse.ArgumentParser(description='NCF') -parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data. -parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"] -parser.add_argument("--train_epochs", type=int, default=14) # The number of epochs used to train. -parser.add_argument("--batch_size", type=int, default=256) # Batch size for training and evaluation -parser.add_argument("--num_neg", type=int, default=4) # The Number of negative instances to pair with a positive instance. -parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file. -parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file. -parser.add_argument("--checkpoint_path", type=str, default="./checkpoint/") # The location of the checkpoint file. -parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], - help='device where the code will be implemented. (Default: Ascend)') -parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. 
(Default: None)') -parser.add_argument('--is_distributed', type=int, default=0, help='if multi device') -parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') -parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') -args = parser.parse_args() - -def test_train(): +def modelarts_pre_process(): + config.checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path) + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_train(): """train entry method""" - if args.is_distributed: - if args.device_target == "Ascend": - init() - context.set_context(device_id=args.device_id) - elif args.device_target == "GPU": - init() + print(config) + print("device id: ", get_device_id()) + print("device num: ", get_device_num()) + print("rank id: ", get_rank_id()) + print("job id: ", get_job_id()) + + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) - args.rank = get_rank() - args.group_size = get_group_size() - device_num = args.group_size + config.is_distributed = bool(get_device_num() > 1) + if config.is_distributed: + config.group_size = get_device_num() context.reset_auto_parallel_context() - context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + context.set_auto_parallel_context(device_num=config.group_size, parallel_mode=ParallelMode.DATA_PARALLEL, parameter_broadcast=True, gradients_mean=True) - else: - context.set_context(device_id=args.device_id) - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - if not os.path.exists(args.output_path): - os.makedirs(args.output_path) + if config.device_target == "Ascend": + context.set_context(device_id=get_device_id()) + init() + elif config.device_target == "GPU": + init() + else: + context.set_context(device_id=get_device_id()) - layers = cfg.layers - num_factors = cfg.num_factors - epochs = args.train_epochs + layers = config.layers + num_factors = config.num_factors + epochs = config.train_epochs - ds_train, num_train_users, num_train_items = create_dataset(test_train=True, data_dir=args.data_path, - dataset=args.dataset, train_epochs=1, - batch_size=args.batch_size, num_neg=args.num_neg) + ds_train, num_train_users, num_train_items = create_dataset(test_train=True, data_dir=config.data_path, + dataset=config.dataset, train_epochs=1, + batch_size=config.batch_size, num_neg=config.num_neg) print("ds_train.size: {}".format(ds_train.get_dataset_size())) ncf_net = NCFModel(num_users=num_train_users, @@ -95,14 +85,14 @@ def test_train(): model = Model(train_net) callback = LossMonitor(per_print_times=ds_train.get_dataset_size()) - ckpt_config = CheckpointConfig(save_checkpoint_steps=(4970845+args.batch_size-1)//(args.batch_size), + ckpt_config = CheckpointConfig(save_checkpoint_steps=(4970845+config.batch_size-1)//(config.batch_size), keep_checkpoint_max=100) - ckpoint_cb = ModelCheckpoint(prefix='NCF', directory=args.checkpoint_path, config=ckpt_config) + ckpoint_cb = ModelCheckpoint(prefix='NCF', directory=config.checkpoint_path, config=ckpt_config) model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb], dataset_sink_mode=True) - + print("="*100 + "Training Finish!" 
+ "="*100) if __name__ == '__main__': - test_train() + run_train() diff --git a/model_zoo/official/recommend/ncf/utils/__init__.py b/model_zoo/official/recommend/ncf/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/recommend/ncf/utils/config.py b/model_zoo/official/recommend/ncf/utils/config.py new file mode 100644 index 0000000000..2c191e9f74 --- /dev/null +++ b/model_zoo/official/recommend/ncf/utils/config.py @@ -0,0 +1,127 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Parse arguments""" + +import os +import ast +import argparse +from pprint import pprint, pformat +import yaml + +class Config: + """ + Configuration namespace. Convert dictionary to members. + """ + def __init__(self, cfg_dict): + for k, v in cfg_dict.items(): + if isinstance(v, (list, tuple)): + setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) + else: + setattr(self, k, Config(v) if isinstance(v, dict) else v) + + def __str__(self): + return pformat(self.__dict__) + + def __repr__(self): + return self.__str__() + + +def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"): + """ + Parse command line arguments to the configuration according to the default yaml. + + Args: + parser: Parent parser. + cfg: Base configuration. + helper: Helper description. + cfg_path: Path to the default yaml config. + """ + parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]", + parents=[parser]) + helper = {} if helper is None else helper + choices = {} if choices is None else choices + for item in cfg: + if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): + help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path) + choice = choices[item] if item in choices else None + if isinstance(cfg[item], bool): + parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice, + help=help_description) + else: + parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice, + help=help_description) + args = parser.parse_args() + return args + + +def parse_yaml(yaml_path): + """ + Parse the yaml config file. + + Args: + yaml_path: Path to the yaml config. 
+ """ + with open(yaml_path, 'r') as fin: + try: + cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) + cfgs = [x for x in cfgs] + if len(cfgs) == 1: + cfg_helper = {} + cfg = cfgs[0] + cfg_choices = {} + elif len(cfgs) == 2: + cfg, cfg_helper = cfgs + cfg_choices = {} + elif len(cfgs) == 3: + cfg, cfg_helper, cfg_choices = cfgs + else: + raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml") + print(cfg_helper) + except: + raise ValueError("Failed to parse yaml") + return cfg, cfg_helper, cfg_choices + + +def merge(args, cfg): + """ + Merge the base config from yaml file and command line arguments. + + Args: + args: Command line arguments. + cfg: Base configuration. + """ + args_var = vars(args) + for item in args_var: + cfg[item] = args_var[item] + return cfg + + +def get_config(): + """ + Get Config according to the yaml file and cli arguments. + """ + parser = argparse.ArgumentParser(description="default name", add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"), + help="Config file path") + path_args, _ = parser.parse_known_args() + default, helper, choices = parse_yaml(path_args.config_path) + pprint(default) + args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) + final_config = merge(args, default) + return Config(final_config) + +config = get_config() diff --git a/model_zoo/official/recommend/ncf/utils/device_adapter.py b/model_zoo/official/recommend/ncf/utils/device_adapter.py new file mode 100644 index 0000000000..92439de46b --- /dev/null +++ b/model_zoo/official/recommend/ncf/utils/device_adapter.py @@ -0,0 +1,27 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Device adapter for ModelArts""" + +from utils.config import config + +if config.enable_modelarts: + from utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = [ + "get_device_id", "get_device_num", "get_rank_id", "get_job_id" +] diff --git a/model_zoo/official/recommend/ncf/utils/local_adapter.py b/model_zoo/official/recommend/ncf/utils/local_adapter.py new file mode 100644 index 0000000000..769fa6dc78 --- /dev/null +++ b/model_zoo/official/recommend/ncf/utils/local_adapter.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Local adapter""" + +import os + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + return "Local Job" diff --git a/model_zoo/official/recommend/ncf/utils/moxing_adapter.py b/model_zoo/official/recommend/ncf/utils/moxing_adapter.py new file mode 100644 index 0000000000..6e7f75a413 --- /dev/null +++ b/model_zoo/official/recommend/ncf/utils/moxing_adapter.py @@ -0,0 +1,122 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Moxing adapter for ModelArts""" + +import os +import functools +from mindspore import context +from mindspore.profiler import Profiler +from utils.config import config + +_global_sync_count = 0 + +def get_device_id(): + device_id = os.getenv('DEVICE_ID', '0') + return int(device_id) + + +def get_device_num(): + device_num = os.getenv('RANK_SIZE', '1') + return int(device_num) + + +def get_rank_id(): + global_rank_id = os.getenv('RANK_ID', '0') + return int(global_rank_id) + + +def get_job_id(): + job_id = os.getenv('JOB_ID') + job_id = job_id if job_id != "" else "default" + return job_id + +def sync_data(from_path, to_path): + """ + Download data from remote obs to local directory if the first url is remote url and the second one is local path + Upload data from local directory to remote obs in contrast. + """ + import moxing as mox + import time + global _global_sync_count + sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count) + _global_sync_count += 1 + + # Each server contains 8 devices as most. + if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): + print("from path: ", from_path) + print("to path: ", to_path) + mox.file.copy_parallel(from_path, to_path) + print("===finish data synchronization===") + try: + os.mknod(sync_lock) + except IOError: + pass + print("===save flag===") + + while True: + if os.path.exists(sync_lock): + break + time.sleep(1) + + print("Finish sync data from {} to {}.".format(from_path, to_path)) + + +def moxing_wrapper(pre_process=None, post_process=None): + """ + Moxing wrapper to download dataset and upload outputs. 
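+
+    Example (as used in train.py):
+        @moxing_wrapper(pre_process=modelarts_pre_process)
+        def run_train():
+            ...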
+ """ + def wrapper(run_func): + @functools.wraps(run_func) + def wrapped_func(*args, **kwargs): + # Download data from data_url + if config.enable_modelarts: + if config.data_url: + sync_data(config.data_url, config.data_path) + print("Dataset downloaded: ", os.listdir(config.data_path)) + if config.checkpoint_url: + sync_data(config.checkpoint_url, config.load_path) + print("Preload downloaded: ", os.listdir(config.load_path)) + if config.train_url: + sync_data(config.train_url, config.output_path) + print("Workspace downloaded: ", os.listdir(config.output_path)) + + context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) + config.device_num = get_device_num() + config.device_id = get_device_id() + if not os.path.exists(config.output_path): + os.makedirs(config.output_path) + + if pre_process: + pre_process() + + if config.enable_profiling: + profiler = Profiler() + + run_func(*args, **kwargs) + + if config.enable_profiling: + profiler.analyse() + + # Upload data to train_url + if config.enable_modelarts: + if post_process: + post_process() + + if config.train_url: + print("Start to copy output directory") + sync_data(config.output_path, config.train_url) + return wrapped_func + return wrapper