Browse Source

fix gnn random walk pr 1977 comments

add fix to random resize decode crop test case

fix pylint issues
tags/v0.6.0-beta
tony_liu2 5 years ago
parent
commit
14899a1410
7 changed files with 96 additions and 16 deletions
  1. +12
    -0
      mindspore/ccsrc/dataset/engine/gnn/graph.cc
  2. +1
    -1
      mindspore/ccsrc/dataset/engine/gnn/graph.h
  3. +4
    -3
      mindspore/dataset/engine/graphdata.py
  4. +4
    -0
      mindspore/dataset/engine/validators.py
  5. +27
    -1
      tests/ut/cpp/dataset/gnn_graph_test.cc
  6. +3
    -3
      tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc
  7. +45
    -8
      tests/ut/python/dataset/test_graphdata.py

+ 12
- 0
mindspore/ccsrc/dataset/engine/gnn/graph.cc View File

@@ -439,6 +439,18 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co
", step_away_param: " + std::to_string(step_away_param);
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (default_node < -1) {
std::string err_msg = "Failed, default_node required to be greater or equal to -1.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (num_walks <= 0) {
std::string err_msg = "Failed, num_walks parameter required to be greater than 0";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (num_workers <= 0) {
std::string err_msg = "Failed, num_workers parameter required to be greater than 0";
RETURN_STATUS_UNEXPECTED(err_msg);
}
step_home_param_ = step_home_param;
step_away_param_ = step_away_param;
default_node_ = default_node;


+ 1
- 1
mindspore/ccsrc/dataset/engine/gnn/graph.h View File

@@ -181,7 +181,7 @@ class Graph {
float step_away_param_; // Inout hyper parameter. Default is 1.0
NodeIdType default_node_;

int32_t num_walks_; // Number of walks per source. Default is 10
int32_t num_walks_; // Number of walks per source. Default is 1
int32_t num_workers_; // The number of worker threads. Default is 1
};



+ 4
- 3
mindspore/dataset/engine/graphdata.py View File

@@ -232,9 +232,10 @@ class GraphData:
Args:
target_nodes (list[int]): Start node list in random walk
meta_path (list[int]): node type for each walk step
step_home_param (float): return hyper parameter in node2vec algorithm
step_away_param (float): inout hyper parameter in node2vec algorithm
default_node (int): default node if no more neighbors found
step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0).
step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0).
default_node (int, optional): default node if no more neighbors found (Default = -1).
A default value of -1 indicates that no node is given.

Returns:
numpy.ndarray: array of nodes.


+ 4
- 0
mindspore/dataset/engine/validators.py View File

@@ -1260,6 +1260,10 @@ def check_gnn_random_walk(method):
# check meta_path; required argument
check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path')

check_type(param_dict.get("step_home_param"), 'step_home_param', float)
check_type(param_dict.get("step_away_param"), 'step_away_param', float)
check_type(param_dict.get("default_node"), 'default_node', int)

return method(*args, **kwargs)

return new_method


+ 27
- 1
tests/ut/cpp/dataset/gnn_graph_test.cc View File

@@ -247,4 +247,30 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>");
}
}

TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
std::string path = "data/mindrecord/testGraphData/sns";
Graph graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());

MetaInfo meta_info;
s = graph.GetMetaInfo(&meta_info);
EXPECT_TRUE(s.IsOk());

std::shared_ptr<Tensor> nodes;
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
EXPECT_TRUE(s.IsOk());
std::vector<NodeIdType> node_list;
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
node_list.push_back(*itr);
}

print_int_vec(node_list, "node list ");
std::vector<NodeType> meta_path(59, 1);
std::shared_ptr<Tensor> walk_path;
s = graph.RandomWalk(node_list, meta_path, 1.0, 1.0, -1, &walk_path);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>");
}

+ 3
- 3
tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc View File

@@ -54,7 +54,7 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp2) {
auto decode_and_crop = static_cast<RandomCropAndResizeOp>(crop_and_decode_copy);
EXPECT_TRUE(crop_and_decode.OneToOne());
GlobalContext::config_manager()->set_seed(42);
for (int k = 0; k < 100; k++) {
for (int k = 0; k < 10; k++) {
(void)crop_and_decode.Compute(raw_input_tensor_, &crop_and_decode_output);
(void)decode_and_crop.Compute(input_tensor_, &decode_and_crop_output);
cv::Mat output1 = CVTensor::AsCVTensor(crop_and_decode_output)->mat().clone();
@@ -104,10 +104,10 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp1) {
int mse_sum, m1, m2, count;
double mse;

for (int k = 0; k < 100; ++k) {
for (int k = 0; k < 10; ++k) {
mse_sum = 0;
count = 0;
for (auto i = 0; i < 100; i++) {
for (auto i = 0; i < 10; i++) {
scale = rd_scale(rd);
aspect = rd_aspect(rd);
crop_width = std::round(std::sqrt(h * w * scale / aspect));


+ 45
- 8
tests/ut/python/dataset/test_graphdata.py View File

@@ -23,6 +23,10 @@ SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"


def test_graphdata_getfullneighbor():
"""
Test get all neighbors
"""
logger.info('test get all neighbors.\n')
g = ds.GraphData(DATASET_FILE, 2)
nodes = g.get_all_nodes(1)
assert len(nodes) == 10
@@ -33,6 +37,10 @@ def test_graphdata_getfullneighbor():


def test_graphdata_getnodefeature_input_check():
"""
Test get node feature input check
"""
logger.info('test getnodefeature input check.\n')
g = ds.GraphData(DATASET_FILE)
with pytest.raises(TypeError):
input_list = [1, [1, 1]]
@@ -80,6 +88,10 @@ def test_graphdata_getnodefeature_input_check():


def test_graphdata_getsampledneighbors():
"""
Test sampled neighbors
"""
logger.info('test get sampled neighbors.\n')
g = ds.GraphData(DATASET_FILE, 1)
edges = g.get_all_edges(0)
nodes = g.get_nodes_from_edges(edges)
@@ -90,6 +102,10 @@ def test_graphdata_getsampledneighbors():


def test_graphdata_getnegsampledneighbors():
"""
Test neg sampled neighbors
"""
logger.info('test get negative sampled neighbors.\n')
g = ds.GraphData(DATASET_FILE, 2)
nodes = g.get_all_nodes(1)
assert len(nodes) == 10
@@ -98,6 +114,10 @@ def test_graphdata_getnegsampledneighbors():


def test_graphdata_graphinfo():
"""
Test graph info
"""
logger.info('test graph info.\n')
g = ds.GraphData(DATASET_FILE, 2)
graph_info = g.graph_info()
assert graph_info['node_type'] == [1, 2]
@@ -155,6 +175,10 @@ class GNNGraphDataset():


def test_graphdata_generatordataset():
"""
Test generator dataset
"""
logger.info('test generator dataset.\n')
g = ds.GraphData(DATASET_FILE)
batch_num = 2
edge_num = g.graph_info()['edge_num'][0]
@@ -173,7 +197,11 @@ def test_graphdata_generatordataset():
assert i == 40


def test_graphdata_randomwalk():
def test_graphdata_randomwalkdefault():
"""
Test random walk defaults
"""
logger.info('test randomwalk with default parameters.\n')
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
nodes = g.get_all_nodes(1)
print(len(nodes))
@@ -184,18 +212,27 @@ def test_graphdata_randomwalk():
assert walks.shape == (33, 40)


def test_graphdata_randomwalk():
"""
Test random walk
"""
logger.info('test random walk with given parameters.\n')
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
nodes = g.get_all_nodes(1)
print(len(nodes))
assert len(nodes) == 33

meta_path = [1 for _ in range(39)]
walks = g.random_walk(nodes, meta_path, 2.0, 0.5, -1)
assert walks.shape == (33, 40)


if __name__ == '__main__':
test_graphdata_getfullneighbor()
logger.info('test_graphdata_getfullneighbor Ended.\n')
test_graphdata_getnodefeature_input_check()
logger.info('test_graphdata_getnodefeature_input_check Ended.\n')
test_graphdata_getsampledneighbors()
logger.info('test_graphdata_getsampledneighbors Ended.\n')
test_graphdata_getnegsampledneighbors()
logger.info('test_graphdata_getnegsampledneighbors Ended.\n')
test_graphdata_graphinfo()
logger.info('test_graphdata_graphinfo Ended.\n')
test_graphdata_generatordataset()
logger.info('test_graphdata_generatordataset Ended.\n')
test_graphdata_randomwalkdefault()
test_graphdata_randomwalk()
logger.info('test_graphdata_randomwalk Ended.\n')

Loading…
Cancel
Save