From 432984d86fee85a8a01d039ef563fcaa71f55e88 Mon Sep 17 00:00:00 2001 From: zhanke Date: Thu, 14 Jan 2021 10:44:45 +0800 Subject: [PATCH] fix device id bug --- model_zoo/official/gnn/bgcf/train.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/gnn/bgcf/train.py b/model_zoo/official/gnn/bgcf/train.py index d0b64a6918..01bb0498c1 100644 --- a/model_zoo/official/gnn/bgcf/train.py +++ b/model_zoo/official/gnn/bgcf/train.py @@ -15,7 +15,6 @@ """ BGCF training script. """ -import os import time from mindspore import Tensor @@ -102,12 +101,12 @@ def train(): if __name__ == "__main__": + parser = parser_args() context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", - save_graphs=False) + save_graphs=False, + device_id=parser.device) - parser = parser_args() - os.environ['DEVICE_ID'] = parser.device train_graph, _, sampled_graph_list = load_graph(parser.datapath) train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs, num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg)