|
|
|
@@ -23,12 +23,12 @@ from mindspore import log as logger |
|
|
|
DATASET_FILE = "../data/mindrecord/testGraphData/testdata" |
|
|
|
|
|
|
|
|
|
|
|
def graphdata_startserver(): |
|
|
|
def graphdata_startserver(server_port): |
|
|
|
""" |
|
|
|
start graphdata server |
|
|
|
""" |
|
|
|
logger.info('test start server.\n') |
|
|
|
ds.GraphData(DATASET_FILE, 1, 'server') |
|
|
|
ds.GraphData(DATASET_FILE, 1, 'server', port=server_port) |
|
|
|
|
|
|
|
|
|
|
|
class RandomBatchedSampler(ds.Sampler): |
|
|
|
@@ -83,11 +83,13 @@ def test_graphdata_distributed(): |
|
|
|
""" |
|
|
|
logger.info('test distributed.\n') |
|
|
|
|
|
|
|
p1 = Process(target=graphdata_startserver) |
|
|
|
server_port = random.randint(10000, 60000) |
|
|
|
|
|
|
|
p1 = Process(target=graphdata_startserver, args=(server_port,)) |
|
|
|
p1.start() |
|
|
|
time.sleep(2) |
|
|
|
|
|
|
|
g = ds.GraphData(DATASET_FILE, 1, 'client') |
|
|
|
g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port) |
|
|
|
nodes = g.get_all_nodes(1) |
|
|
|
assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] |
|
|
|
row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3]) |
|
|
|
|