Browse Source

fix device id bug

tags/v1.2.0-rc1
zhanke 4 years ago
parent
commit
432984d86f
1 changed files with 3 additions and 4 deletions
  1. +3
    -4
      model_zoo/official/gnn/bgcf/train.py

+ 3
- 4
model_zoo/official/gnn/bgcf/train.py View File

@@ -15,7 +15,6 @@
""" """
BGCF training script. BGCF training script.
""" """
import os
import time import time


from mindspore import Tensor from mindspore import Tensor
@@ -102,12 +101,12 @@ def train():




if __name__ == "__main__": if __name__ == "__main__":
parser = parser_args()
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", device_target="Ascend",
save_graphs=False)
save_graphs=False,
device_id=parser.device)


parser = parser_args()
os.environ['DEVICE_ID'] = parser.device
train_graph, _, sampled_graph_list = load_graph(parser.datapath) train_graph, _, sampled_graph_list = load_graph(parser.datapath)
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs, train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,
num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg) num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg)


Loading…
Cancel
Save