Browse Source

!11320 [ModelZoo]fix bgcf train and eval device id bug

From: @zhan_ke
Reviewed-by: @c_34,@oacjiewen
Signed-off-by: @c_34
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
064ee0b383
4 changed files with 6 additions and 10 deletions
  1. +3
    -5
      model_zoo/official/gnn/bgcf/eval.py
  2. +1
    -2
      model_zoo/official/gnn/bgcf/scripts/run_eval_ascend.sh
  3. +1
    -2
      model_zoo/official/gnn/bgcf/scripts/run_train_ascend.sh
  4. +1
    -1
      model_zoo/official/gnn/bgcf/train.py

+ 3
- 5
model_zoo/official/gnn/bgcf/eval.py View File

@@ -15,7 +15,6 @@
"""
BGCF evaluation script.
"""
import os
import datetime
import mindspore.context as context
@@ -78,12 +77,11 @@ def evaluation():
if __name__ == "__main__":
parser = parser_args()
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False)
parser = parser_args()
os.environ['DEVICE_ID'] = parser.device
save_graphs=False,
device_id=int(parser.device))
train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,


+ 1
- 2
model_zoo/official/gnn/bgcf/scripts/run_eval_ascend.sh View File

@@ -17,7 +17,6 @@
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0

if [ -d "eval" ];
@@ -31,7 +30,7 @@ cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
echo "start evaluation"

python eval.py --datapath=../data_mr --ckptpath=../ckpts &> log &



+ 1
- 2
model_zoo/official/gnn/bgcf/scripts/run_train_ascend.sh View File

@@ -17,7 +17,6 @@
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0

if [ -d "train" ];
@@ -37,7 +36,7 @@ cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
env > env.log
echo "start training for device $DEVICE_ID"
echo "start training"

python train.py --datapath=../data_mr --ckptpath=../ckpts &> log &



+ 1
- 1
model_zoo/official/gnn/bgcf/train.py View File

@@ -105,7 +105,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=parser.device)
device_id=int(parser.device))

train_graph, _, sampled_graph_list = load_graph(parser.datapath)
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,


Loading…
Cancel
Save