| @@ -15,7 +15,6 @@ | |||||
| """ | """ | ||||
| BGCF training script. | BGCF training script. | ||||
| """ | """ | ||||
| import os | |||||
| import time | import time | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| @@ -102,12 +101,12 @@ def train(): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| parser = parser_args() | |||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| device_target="Ascend", | device_target="Ascend", | ||||
| save_graphs=False) | |||||
| save_graphs=False, | |||||
| device_id=parser.device) | |||||
| parser = parser_args() | |||||
| os.environ['DEVICE_ID'] = parser.device | |||||
| train_graph, _, sampled_graph_list = load_graph(parser.datapath) | train_graph, _, sampled_graph_list = load_graph(parser.datapath) | ||||
| train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs, | train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs, | ||||
| num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg) | num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg) | ||||