From: @zhan_ke Reviewed-by: @c_34,@oacjiewen Signed-off-by: @c_34tags/v1.1.1
| @@ -15,7 +15,6 @@ | |||||
| """ | """ | ||||
| BGCF evaluation script. | BGCF evaluation script. | ||||
| """ | """ | ||||
| import os | |||||
| import datetime | import datetime | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| @@ -78,12 +77,11 @@ def evaluation(): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| parser = parser_args() | |||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| device_target="Ascend", | device_target="Ascend", | ||||
| save_graphs=False) | |||||
| parser = parser_args() | |||||
| os.environ['DEVICE_ID'] = parser.device | |||||
| save_graphs=False, | |||||
| device_id=int(parser.device)) | |||||
| train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath) | train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath) | ||||
| test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs, | test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs, | ||||
| @@ -17,7 +17,6 @@ | |||||
| ulimit -u unlimited | ulimit -u unlimited | ||||
| export DEVICE_NUM=1 | export DEVICE_NUM=1 | ||||
| export RANK_SIZE=$DEVICE_NUM | export RANK_SIZE=$DEVICE_NUM | ||||
| export DEVICE_ID=0 | |||||
| export RANK_ID=0 | export RANK_ID=0 | ||||
| if [ -d "eval" ]; | if [ -d "eval" ]; | ||||
| @@ -31,7 +30,7 @@ cp *.sh ./eval | |||||
| cp -r ../src ./eval | cp -r ../src ./eval | ||||
| cd ./eval || exit | cd ./eval || exit | ||||
| env > env.log | env > env.log | ||||
| echo "start evaluation for device $DEVICE_ID" | |||||
| echo "start evaluation" | |||||
| python eval.py --datapath=../data_mr --ckptpath=../ckpts &> log & | python eval.py --datapath=../data_mr --ckptpath=../ckpts &> log & | ||||
| @@ -17,7 +17,6 @@ | |||||
| ulimit -u unlimited | ulimit -u unlimited | ||||
| export DEVICE_NUM=1 | export DEVICE_NUM=1 | ||||
| export RANK_SIZE=$DEVICE_NUM | export RANK_SIZE=$DEVICE_NUM | ||||
| export DEVICE_ID=0 | |||||
| export RANK_ID=0 | export RANK_ID=0 | ||||
| if [ -d "train" ]; | if [ -d "train" ]; | ||||
| @@ -37,7 +36,7 @@ cp *.sh ./train | |||||
| cp -r ../src ./train | cp -r ../src ./train | ||||
| cd ./train || exit | cd ./train || exit | ||||
| env > env.log | env > env.log | ||||
| echo "start training for device $DEVICE_ID" | |||||
| echo "start training" | |||||
| python train.py --datapath=../data_mr --ckptpath=../ckpts &> log & | python train.py --datapath=../data_mr --ckptpath=../ckpts &> log & | ||||
| @@ -105,7 +105,7 @@ if __name__ == "__main__": | |||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| device_target="Ascend", | device_target="Ascend", | ||||
| save_graphs=False, | save_graphs=False, | ||||
| device_id=parser.device) | |||||
| device_id=int(parser.device)) | |||||
| train_graph, _, sampled_graph_list = load_graph(parser.datapath) | train_graph, _, sampled_graph_list = load_graph(parser.datapath) | ||||
| train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs, | train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs, | ||||