diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multigpu_train.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multigpu_train.sh new file mode 100644 index 0000000000..355102764a --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multigpu_train.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +export RANK_SIZE=$1 +export EPOCH_SIZE=$2 +export DATASET=$3 +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +export MS_WORKER_NUM=$RANK_SIZE +export MS_SERVER_NUM=$4 +export MS_SCHED_HOST=$5 +export MS_SCHED_PORT=$6 + +export MS_ROLE=MS_SCHED +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \ + --device_target='GPU' --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ + --vocab_cache_size=300000 >sched_$i.log 2>&1 & +done + +export MS_ROLE=MS_PSERVER +for((i=0;i<$MS_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \ + --device_target='GPU' --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ + --vocab_cache_size=300000 >server_$i.log 2>&1 & +done + +export MS_ROLE=MS_WORKER +rm -rf ${execute_path}/worker/ +mkdir ${execute_path}/worker/ +cd ${execute_path}/worker/ || exit +mpirun --allow-run-as-root -n $RANK_SIZE python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \ + --device_target='GPU' --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ + --vocab_cache_size=300000 --full_batch=1 --dropout_flag=1 >worker.log 2>&1 & diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multinpu_train.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multinpu_train.sh new file mode 100644 index 0000000000..fc9f9a455b --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multinpu_train.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +export RANK_SIZE=$1 +export EPOCH_SIZE=$2 +export DATASET=$3 +export RANK_TABLE_FILE=$4 +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +export MS_WORKER_NUM=$RANK_SIZE +export MS_SERVER_NUM=$5 +export MS_SCHED_HOST=$6 +export MS_SCHED_PORT=$7 + +export MS_ROLE=MS_SCHED +for((i=0;i<1;i++)); +do + rm -rf ${execute_path}/sched_$i/ + mkdir ${execute_path}/sched_$i/ + cd ${execute_path}/sched_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \ + --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ + --vocab_cache_size=300000 >sched_$i.log 2>&1 & +done + +export MS_ROLE=MS_PSERVER +for((i=0;i<$MS_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \ + --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ + --vocab_cache_size=300000 >server_$i.log 2>&1 & +done + +export MS_ROLE=MS_WORKER +for((i=0;i<$MS_WORKER_NUM;i++)); +do + rm -rf ${execute_path}/worker_$i/ + mkdir ${execute_path}/worker_$i/ + cd ${execute_path}/worker_$i/ || exit + export RANK_ID=$i + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \ + --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \ + --vocab_cache_size=300000 --full_batch=1 --dropout_flag=1 >worker_$i.log 2>&1 & +done diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_standalone_train.sh b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_standalone_train.sh new file mode 100644 index 0000000000..7341cfa17b --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_standalone_train.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +execute_path=$(pwd) +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +export EPOCH_SIZE=$1 +export DEVICE_TARGET=$2 +export DATASET=$3 +export MS_COMM_TYPE=zmq +export MS_SCHED_NUM=1 +export MS_WORKER_NUM=1 +export MS_SERVER_NUM=$4 +export MS_SCHED_HOST=$5 +export MS_SCHED_PORT=$6 + +export MS_ROLE=MS_SCHED +rm -rf ${execute_path}/sched/ +mkdir ${execute_path}/sched/ +cd ${execute_path}/sched/ || exit +export DEVICE_ID=$i +python -s ${self_path}/../train_and_eval_parameter_server.py --epochs=$EPOCH_SIZE --device_target=$DEVICE_TARGET --data_path=$DATASET \ + --parameter_server=1 --vocab_cache_size=300000 >sched.log 2>&1 & + +export MS_ROLE=MS_PSERVER +for((i=0;i<$MS_SERVER_NUM;i++)); +do + rm -rf ${execute_path}/server_$i/ + mkdir ${execute_path}/server_$i/ + cd ${execute_path}/server_$i/ || exit + export DEVICE_ID=$i + python -s ${self_path}/../train_and_eval_parameter_server.py --epochs=$EPOCH_SIZE --device_target=$DEVICE_TARGET --data_path=$DATASET \ + --parameter_server=1 --vocab_cache_size=300000 >server_$i.log 2>&1 & +done + +export MS_ROLE=MS_WORKER +rm -rf ${execute_path}/worker/ +mkdir ${execute_path}/worker/ +cd ${execute_path}/worker/ || exit +export DEVICE_ID=$i +python -s ${self_path}/../train_and_eval_parameter_server.py --epochs=$EPOCH_SIZE --device_target=$DEVICE_TARGET --data_path=$DATASET \ + --parameter_server=1 --vocab_cache_size=300000 \ + --dropout_flag=1 >worker.log 2>&1 & diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_cache_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_cache_distribute.py new file mode 100644 index 0000000000..7002dd07b5 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_cache_distribute.py @@ -0,0 +1,141 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_multinpu.""" + + +import os +import sys +import mindspore.dataset.engine as de +from mindspore import Model, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor +from mindspore.context import ParallelMode +from mindspore.communication.management import get_rank, get_group_size, init +from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple + +from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel +from src.callbacks import LossCallBack, EvalCallBack +from src.datasets import create_dataset, DataType +from src.metrics import AUCMetric +from src.config import WideDeepConfig + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def get_WideDeep_net(config): + """ + Get network of wide&deep model. + """ + WideDeep_net = WideDeepModel(config) + loss_net = NetWithLossClass(WideDeep_net, config) + loss_net = VirtualDatasetCellTriple(loss_net) + train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server), + cache_enable=bool(config.vocab_cache_size > 0)) + eval_net = PredictWithSigmoid(WideDeep_net) + eval_net = VirtualDatasetCellTriple(eval_net) + return train_net, eval_net + + +class ModelBuilder(): + """ + ModelBuilder + """ + + def __init__(self): + pass + + def get_hook(self): + pass + + def get_train_hook(self): + hooks = [] + callback = LossCallBack() + hooks.append(callback) + if int(os.getenv('DEVICE_ID')) == 0: + pass + return hooks + + def get_net(self, config): + return get_WideDeep_net(config) + + +def train_and_eval(config): + """ + test_train_eval + """ + data_path = config.data_path + batch_size = config.batch_size + epochs = config.epochs + if config.dataset_type == "tfrecord": + dataset_type = DataType.TFRECORD + elif config.dataset_type == "mindrecord": + dataset_type = DataType.MINDRECORD + else: + dataset_type = DataType.H5 + print("epochs is {}".format(epochs)) + if config.full_batch: + context.set_auto_parallel_context(full_batch=True) + de.config.set_seed(1) + ds_train = create_dataset(data_path, train_mode=True, epochs=1, + batch_size=batch_size*get_group_size(), data_type=dataset_type) + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + batch_size=batch_size*get_group_size(), data_type=dataset_type) + else: + ds_train = create_dataset(data_path, train_mode=True, epochs=1, + batch_size=batch_size, rank_id=get_rank(), + rank_size=get_group_size(), data_type=dataset_type) + ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + batch_size=batch_size, rank_id=get_rank(), + rank_size=get_group_size(), data_type=dataset_type) + print("ds_train.size: {}".format(ds_train.get_dataset_size())) + print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) + + net_builder = ModelBuilder() + + train_net, eval_net = net_builder.get_net(config) + train_net.set_train() + auc_metric = AUCMetric() + + model = Model(train_net, eval_network=eval_net, + metrics={"auc": auc_metric}) + + eval_callback = EvalCallBack( + model, ds_eval, auc_metric, config) + + callback = LossCallBack(config=config, per_print_times=20) + ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, + keep_checkpoint_max=5, integrated_save=False) + ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', + directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', config=ckptconfig) + context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) + callback_list = [TimeMonitor( + ds_train.get_dataset_size()), eval_callback, callback] + callback_list.append(ckpoint_cb) + model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=True) + + +if __name__ == "__main__": + wide_deep_config = WideDeepConfig() + wide_deep_config.argparse_init() + context.set_context(mode=context.GRAPH_MODE, + device_target=wide_deep_config.device_target, save_graphs=True) + context.set_context(variable_memory_max_size="24GB") + context.set_context(enable_sparse=True) + context.set_ps_context(enable_ps=True) + init() + context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank())) + + context.set_auto_parallel_context( + parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True) + train_and_eval(wide_deep_config)