Browse Source

!1738 modify widedeep

Merge pull request !1738 from wukesong/widedeep
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
b5957b5af6
3 changed files with 28 additions and 5 deletions
  1. +17
    -0
      model_zoo/wide_and_deep/run_multinpu_train.sh
  2. +1
    -1
      model_zoo/wide_and_deep/train_and_test.py
  3. +10
    -4
      model_zoo/wide_and_deep/train_and_test_multinpu.py

+ 17
- 0
model_zoo/wide_and_deep/run_multinpu_train.sh View File

@@ -0,0 +1,17 @@
#!/bin/bash
# Launch a distributed Wide&Deep training run, one background pytest
# process per NPU device (RANK_SIZE devices total).
# Usage: bash run_multinpu_train.sh
execute_path=$(pwd)

# HCCL rank table describing the 8-NPU cluster topology; both variables
# point at the same file (MINDSPORE_HCCL_CONFIG_PATH is the name some
# MindSpore versions read).
export RANK_TABLE_FILE=${execute_path}/rank_table_8p.json
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=${execute_path}/rank_table_8p.json

# One fresh working directory per device; loop bound follows RANK_SIZE
# so changing the device count only requires editing it in one place.
for ((i = 0; i < RANK_SIZE; i++)); do
    rm -rf "${execute_path}/device_$i/"
    mkdir "${execute_path}/device_$i/"
    cd "${execute_path}/device_$i/" || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
    # Run in the background; per-device output captured in its own log.
    pytest -s "${execute_path}/train_and_test_multinpu.py" >"train_deep$i.log" 2>&1 &
done

+ 1
- 1
model_zoo/wide_and_deep/train_and_test.py View File

@@ -82,7 +82,7 @@ def test_train_eval(config):
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig)

out = model.eval(ds_eval)


+ 10
- 4
model_zoo/wide_and_deep/train_and_test_multinpu.py View File

@@ -30,7 +30,7 @@ from src.metrics import AUCMetric
from src.config import WideDeepConfig

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=GRAPH_MODE, device_target="Davinci", save_graph=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
init()

@@ -71,8 +71,8 @@ def test_train_eval():
test_train_eval
"""
np.random.seed(1000)
config = WideDeepConfig
data_path = Config.data_path
config = WideDeepConfig()
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
@@ -94,8 +94,14 @@ def test_train_eval():
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
out = model.eval(ds_eval)
print("=====" * 5 + "model.eval() initialized: {}".format(out))
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])


if __name__ == "__main__":
test_train_eval()

Loading…
Cancel
Save