|
|
|
@@ -15,7 +15,6 @@ |
|
|
|
""" |
|
|
|
BGCF training script. |
|
|
|
""" |
|
|
|
import os |
|
|
|
import time |
|
|
|
|
|
|
|
from mindspore import Tensor |
|
|
|
@@ -102,12 +101,12 @@ def train(): |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
parser = parser_args() |
|
|
|
context.set_context(mode=context.GRAPH_MODE, |
|
|
|
device_target="Ascend", |
|
|
|
save_graphs=False) |
|
|
|
save_graphs=False, |
|
|
|
device_id=parser.device) |
|
|
|
|
|
|
|
parser = parser_args() |
|
|
|
os.environ['DEVICE_ID'] = parser.device |
|
|
|
train_graph, _, sampled_graph_list = load_graph(parser.datapath) |
|
|
|
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs, |
|
|
|
num_samples=parser.raw_neighs, num_bgcn_neigh=parser.gnew_neighs, num_neg=parser.num_neg) |
|
|
|
|