
add parameter server cache training scripts for wide&deep to model zoo

tags/v1.1.0
lizhenyu committed 2dc1f4637c 5 years ago
4 changed files with 323 additions and 0 deletions
  1. model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multigpu_train.sh (+58, -0)
  2. model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multinpu_train.sh (+68, -0)
  3. model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_standalone_train.sh (+56, -0)
  4. model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_cache_distribute.py (+141, -0)

model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multigpu_train.sh (+58, -0)

@@ -0,0 +1,58 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
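
# Positional arguments:
#   $1 RANK_SIZE  $2 EPOCH_SIZE  $3 DATASET  $4 MS_SERVER_NUM  $5 MS_SCHED_HOST  $6 MS_SCHED_PORT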
export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DATASET=$3
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=$RANK_SIZE
export MS_SERVER_NUM=$4
export MS_SCHED_HOST=$5
export MS_SCHED_PORT=$6
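
# One scheduler, $MS_SERVER_NUM parameter servers, and $RANK_SIZE workers are
# launched below; the MS_ROLE environment variable tells each process its role.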

export MS_ROLE=MS_SCHED
for((i=0;i<1;i++));
do
rm -rf ${execute_path}/sched_$i/
mkdir ${execute_path}/sched_$i/
cd ${execute_path}/sched_$i/ || exit
python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \
--device_target='GPU' --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=300000 >sched_$i.log 2>&1 &
done

export MS_ROLE=MS_PSERVER
for((i=0;i<$MS_SERVER_NUM;i++));
do
rm -rf ${execute_path}/server_$i/
mkdir ${execute_path}/server_$i/
cd ${execute_path}/server_$i/ || exit
python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \
--device_target='GPU' --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=300000 >server_$i.log 2>&1 &
done

export MS_ROLE=MS_WORKER
rm -rf ${execute_path}/worker/
mkdir ${execute_path}/worker/
cd ${execute_path}/worker/ || exit
mpirun --allow-run-as-root -n $RANK_SIZE python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \
--device_target='GPU' --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=300000 --full_batch=1 --dropout_flag=1 >worker.log 2>&1 &
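
Example invocation (a sketch; the GPU count, epoch count, dataset path, server count, and scheduler address are illustrative assumptions, not values from this commit):

bash run_parameter_server_cache_multigpu_train.sh 8 10 /path/to/criteo_data 1 127.0.0.1 8081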

model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_multinpu_train.sh (+68, -0)

@@ -0,0 +1,68 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
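
# Positional arguments:
#   $1 RANK_SIZE  $2 EPOCH_SIZE  $3 DATASET  $4 RANK_TABLE_FILE (the Ascend
#   rank table json describing the device topology)  $5 MS_SERVER_NUM
#   $6 MS_SCHED_HOST  $7 MS_SCHED_PORT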
export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DATASET=$3
export RANK_TABLE_FILE=$4
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=$RANK_SIZE
export MS_SERVER_NUM=$5
export MS_SCHED_HOST=$6
export MS_SCHED_PORT=$7

export MS_ROLE=MS_SCHED
for((i=0;i<1;i++));
do
rm -rf ${execute_path}/sched_$i/
mkdir ${execute_path}/sched_$i/
cd ${execute_path}/sched_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \
--data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=300000 >sched_$i.log 2>&1 &
done

export MS_ROLE=MS_PSERVER
for((i=0;i<$MS_SERVER_NUM;i++));
do
rm -rf ${execute_path}/server_$i/
mkdir ${execute_path}/server_$i/
cd ${execute_path}/server_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \
--data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=300000 >server_$i.log 2>&1 &
done

export MS_ROLE=MS_WORKER
for((i=0;i<$MS_WORKER_NUM;i++));
do
rm -rf ${execute_path}/worker_$i/
mkdir ${execute_path}/worker_$i/
cd ${execute_path}/worker_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server_cache_distribute.py \
--data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=300000 --full_batch=1 --dropout_flag=1 >worker_$i.log 2>&1 &
done
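
Example invocation (a sketch; the rank table path and the other values are illustrative assumptions, not values from this commit):

bash run_parameter_server_cache_multinpu_train.sh 8 10 /path/to/criteo_data /path/to/rank_table_8p.json 1 127.0.0.1 8081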

model_zoo/official/recommend/wide_and_deep/script/run_parameter_server_cache_standalone_train.sh (+56, -0)

@@ -0,0 +1,56 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
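
# Positional arguments:
#   $1 EPOCH_SIZE  $2 DEVICE_TARGET  $3 DATASET  $4 MS_SERVER_NUM  $5 MS_SCHED_HOST  $6 MS_SCHED_PORT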
export EPOCH_SIZE=$1
export DEVICE_TARGET=$2
export DATASET=$3
export MS_COMM_TYPE=zmq
export MS_SCHED_NUM=1
export MS_WORKER_NUM=1
export MS_SERVER_NUM=$4
export MS_SCHED_HOST=$5
export MS_SCHED_PORT=$6

export MS_ROLE=MS_SCHED
rm -rf ${execute_path}/sched/
mkdir ${execute_path}/sched/
cd ${execute_path}/sched/ || exit
export DEVICE_ID=0  # there is no loop here, so $i is undefined; the single scheduler uses device 0
python -s ${self_path}/../train_and_eval_parameter_server.py --epochs=$EPOCH_SIZE --device_target=$DEVICE_TARGET --data_path=$DATASET \
--parameter_server=1 --vocab_cache_size=300000 >sched.log 2>&1 &

export MS_ROLE=MS_PSERVER
for((i=0;i<$MS_SERVER_NUM;i++));
do
rm -rf ${execute_path}/server_$i/
mkdir ${execute_path}/server_$i/
cd ${execute_path}/server_$i/ || exit
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server.py --epochs=$EPOCH_SIZE --device_target=$DEVICE_TARGET --data_path=$DATASET \
--parameter_server=1 --vocab_cache_size=300000 >server_$i.log 2>&1 &
done

export MS_ROLE=MS_WORKER
rm -rf ${execute_path}/worker/
mkdir ${execute_path}/worker/
cd ${execute_path}/worker/ || exit
export DEVICE_ID=0  # $i would be stale from the server loop above; the single worker uses device 0
python -s ${self_path}/../train_and_eval_parameter_server.py --epochs=$EPOCH_SIZE --device_target=$DEVICE_TARGET --data_path=$DATASET \
--parameter_server=1 --vocab_cache_size=300000 \
--dropout_flag=1 >worker.log 2>&1 &
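
Example invocation (a sketch; the epoch count, dataset path, server count, and scheduler address are illustrative assumptions, not values from this commit):

bash run_parameter_server_cache_standalone_train.sh 10 GPU /path/to/criteo_data 1 127.0.0.1 8081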

model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_cache_distribute.py (+141, -0)

@@ -0,0 +1,141 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_multinpu."""


import os
import sys

# This append must run before the src imports below to have any effect.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import mindspore.dataset.engine as de
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple

from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset, DataType
from src.metrics import AUCMetric
from src.config import WideDeepConfig


def get_WideDeep_net(config):
"""
Get network of wide&deep model.
"""
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
loss_net = VirtualDatasetCellTriple(loss_net)
train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server),
cache_enable=bool(config.vocab_cache_size > 0))
eval_net = PredictWithSigmoid(WideDeep_net)
eval_net = VirtualDatasetCellTriple(eval_net)
return train_net, eval_net
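
# TrainStepWrap receives parameter_server=True and, when vocab_cache_size > 0,
# cache_enable=True; the launcher scripts pass --vocab_cache_size=300000, so
# the embedding cache path is enabled in all three run modes.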


class ModelBuilder():
"""
ModelBuilder
"""

def __init__(self):
pass

    def get_hook(self):
        # No generic hooks are registered for this model.
        pass

    def get_train_hook(self):
        hooks = []
        callback = LossCallBack()
        hooks.append(callback)
        # DEVICE_ID may be unset (e.g. by the GPU launcher), so default to '0'.
        if int(os.getenv('DEVICE_ID', '0')) == 0:
            pass
        return hooks

def get_net(self, config):
return get_WideDeep_net(config)


def train_and_eval(config):
    """
    Train and evaluate the Wide&Deep model with the parameter server cache.
    """
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
if config.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif config.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
dataset_type = DataType.H5
print("epochs is {}".format(epochs))
if config.full_batch:
context.set_auto_parallel_context(full_batch=True)
de.config.set_seed(1)
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size*get_group_size(), data_type=dataset_type)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size*get_group_size(), data_type=dataset_type)
else:
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(),
rank_size=get_group_size(), data_type=dataset_type)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, rank_id=get_rank(),
rank_size=get_group_size(), data_type=dataset_type)
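    # With full_batch, every worker loads the full global batch
    # (batch_size * group size) and auto-parallel splits it across ranks;
    # otherwise each rank reads only its own shard via rank_id/rank_size.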
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

net_builder = ModelBuilder()

train_net, eval_net = net_builder.get_net(config)
train_net.set_train()
auc_metric = AUCMetric()

model = Model(train_net, eval_network=eval_net,
metrics={"auc": auc_metric})

eval_callback = EvalCallBack(
model, ds_eval, auc_metric, config)

callback = LossCallBack(config=config, per_print_times=20)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
keep_checkpoint_max=5, integrated_save=False)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', config=ckptconfig)
context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
callback_list = [TimeMonitor(
ds_train.get_dataset_size()), eval_callback, callback]
callback_list.append(ckpoint_cb)
model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=True)
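

# This module is launched by the multi-GPU and multi-NPU scripts under script/,
# which export MS_ROLE, MS_SCHED_HOST, MS_SCHED_PORT and related variables and
# start one process per role (the standalone script drives
# train_and_eval_parameter_server.py instead).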


if __name__ == "__main__":
wide_deep_config = WideDeepConfig()
wide_deep_config.argparse_init()
context.set_context(mode=context.GRAPH_MODE,
device_target=wide_deep_config.device_target, save_graphs=True)
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
context.set_ps_context(enable_ps=True)
init()
context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))

context.set_auto_parallel_context(
parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
train_and_eval(wide_deep_config)
