add fix to random resize decode crop test case fix pylint issuestags/v0.6.0-beta
| @@ -439,6 +439,18 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co | |||||
| ", step_away_param: " + std::to_string(step_away_param); | ", step_away_param: " + std::to_string(step_away_param); | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| if (default_node < -1) { | |||||
| std::string err_msg = "Failed, default_node required to be greater or equal to -1."; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| if (num_walks <= 0) { | |||||
| std::string err_msg = "Failed, num_walks parameter required to be greater than 0"; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| if (num_workers <= 0) { | |||||
| std::string err_msg = "Failed, num_workers parameter required to be greater than 0"; | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| step_home_param_ = step_home_param; | step_home_param_ = step_home_param; | ||||
| step_away_param_ = step_away_param; | step_away_param_ = step_away_param; | ||||
| default_node_ = default_node; | default_node_ = default_node; | ||||
| @@ -181,7 +181,7 @@ class Graph { | |||||
| float step_away_param_; // Inout hyper parameter. Default is 1.0 | float step_away_param_; // Inout hyper parameter. Default is 1.0 | ||||
| NodeIdType default_node_; | NodeIdType default_node_; | ||||
| int32_t num_walks_; // Number of walks per source. Default is 10 | |||||
| int32_t num_walks_; // Number of walks per source. Default is 1 | |||||
| int32_t num_workers_; // The number of worker threads. Default is 1 | int32_t num_workers_; // The number of worker threads. Default is 1 | ||||
| }; | }; | ||||
| @@ -232,9 +232,10 @@ class GraphData: | |||||
| Args: | Args: | ||||
| target_nodes (list[int]): Start node list in random walk | target_nodes (list[int]): Start node list in random walk | ||||
| meta_path (list[int]): node type for each walk step | meta_path (list[int]): node type for each walk step | ||||
| step_home_param (float): return hyper parameter in node2vec algorithm | |||||
| step_away_param (float): inout hyper parameter in node2vec algorithm | |||||
| default_node (int): default node if no more neighbors found | |||||
| step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0). | |||||
| step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0). | |||||
| default_node (int, optional): default node if no more neighbors found (Default = -1). | |||||
| A default value of -1 indicates that no node is given. | |||||
| Returns: | Returns: | ||||
| numpy.ndarray: array of nodes. | numpy.ndarray: array of nodes. | ||||
| @@ -1260,6 +1260,10 @@ def check_gnn_random_walk(method): | |||||
| # check meta_path; required argument | # check meta_path; required argument | ||||
| check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path') | check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path') | ||||
| check_type(param_dict.get("step_home_param"), 'step_home_param', float) | |||||
| check_type(param_dict.get("step_away_param"), 'step_away_param', float) | |||||
| check_type(param_dict.get("default_node"), 'default_node', int) | |||||
| return method(*args, **kwargs) | return method(*args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -247,4 +247,30 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) { | |||||
| s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path); | s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path); | ||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); | EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); | ||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) { | |||||
| std::string path = "data/mindrecord/testGraphData/sns"; | |||||
| Graph graph(path, 1); | |||||
| Status s = graph.Init(); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| MetaInfo meta_info; | |||||
| s = graph.GetMetaInfo(&meta_info); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| std::shared_ptr<Tensor> nodes; | |||||
| s = graph.GetAllNodes(meta_info.node_type[0], &nodes); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| std::vector<NodeIdType> node_list; | |||||
| for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | |||||
| node_list.push_back(*itr); | |||||
| } | |||||
| print_int_vec(node_list, "node list "); | |||||
| std::vector<NodeType> meta_path(59, 1); | |||||
| std::shared_ptr<Tensor> walk_path; | |||||
| s = graph.RandomWalk(node_list, meta_path, 1.0, 1.0, -1, &walk_path); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); | |||||
| } | |||||
| @@ -54,7 +54,7 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp2) { | |||||
| auto decode_and_crop = static_cast<RandomCropAndResizeOp>(crop_and_decode_copy); | auto decode_and_crop = static_cast<RandomCropAndResizeOp>(crop_and_decode_copy); | ||||
| EXPECT_TRUE(crop_and_decode.OneToOne()); | EXPECT_TRUE(crop_and_decode.OneToOne()); | ||||
| GlobalContext::config_manager()->set_seed(42); | GlobalContext::config_manager()->set_seed(42); | ||||
| for (int k = 0; k < 100; k++) { | |||||
| for (int k = 0; k < 10; k++) { | |||||
| (void)crop_and_decode.Compute(raw_input_tensor_, &crop_and_decode_output); | (void)crop_and_decode.Compute(raw_input_tensor_, &crop_and_decode_output); | ||||
| (void)decode_and_crop.Compute(input_tensor_, &decode_and_crop_output); | (void)decode_and_crop.Compute(input_tensor_, &decode_and_crop_output); | ||||
| cv::Mat output1 = CVTensor::AsCVTensor(crop_and_decode_output)->mat().clone(); | cv::Mat output1 = CVTensor::AsCVTensor(crop_and_decode_output)->mat().clone(); | ||||
| @@ -104,10 +104,10 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp1) { | |||||
| int mse_sum, m1, m2, count; | int mse_sum, m1, m2, count; | ||||
| double mse; | double mse; | ||||
| for (int k = 0; k < 100; ++k) { | |||||
| for (int k = 0; k < 10; ++k) { | |||||
| mse_sum = 0; | mse_sum = 0; | ||||
| count = 0; | count = 0; | ||||
| for (auto i = 0; i < 100; i++) { | |||||
| for (auto i = 0; i < 10; i++) { | |||||
| scale = rd_scale(rd); | scale = rd_scale(rd); | ||||
| aspect = rd_aspect(rd); | aspect = rd_aspect(rd); | ||||
| crop_width = std::round(std::sqrt(h * w * scale / aspect)); | crop_width = std::round(std::sqrt(h * w * scale / aspect)); | ||||
| @@ -23,6 +23,10 @@ SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" | |||||
| def test_graphdata_getfullneighbor(): | def test_graphdata_getfullneighbor(): | ||||
| """ | |||||
| Test get all neighbors | |||||
| """ | |||||
| logger.info('test get all neighbors.\n') | |||||
| g = ds.GraphData(DATASET_FILE, 2) | g = ds.GraphData(DATASET_FILE, 2) | ||||
| nodes = g.get_all_nodes(1) | nodes = g.get_all_nodes(1) | ||||
| assert len(nodes) == 10 | assert len(nodes) == 10 | ||||
| @@ -33,6 +37,10 @@ def test_graphdata_getfullneighbor(): | |||||
| def test_graphdata_getnodefeature_input_check(): | def test_graphdata_getnodefeature_input_check(): | ||||
| """ | |||||
| Test get node feature input check | |||||
| """ | |||||
| logger.info('test getnodefeature input check.\n') | |||||
| g = ds.GraphData(DATASET_FILE) | g = ds.GraphData(DATASET_FILE) | ||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| input_list = [1, [1, 1]] | input_list = [1, [1, 1]] | ||||
| @@ -80,6 +88,10 @@ def test_graphdata_getnodefeature_input_check(): | |||||
| def test_graphdata_getsampledneighbors(): | def test_graphdata_getsampledneighbors(): | ||||
| """ | |||||
| Test sampled neighbors | |||||
| """ | |||||
| logger.info('test get sampled neighbors.\n') | |||||
| g = ds.GraphData(DATASET_FILE, 1) | g = ds.GraphData(DATASET_FILE, 1) | ||||
| edges = g.get_all_edges(0) | edges = g.get_all_edges(0) | ||||
| nodes = g.get_nodes_from_edges(edges) | nodes = g.get_nodes_from_edges(edges) | ||||
| @@ -90,6 +102,10 @@ def test_graphdata_getsampledneighbors(): | |||||
| def test_graphdata_getnegsampledneighbors(): | def test_graphdata_getnegsampledneighbors(): | ||||
| """ | |||||
| Test neg sampled neighbors | |||||
| """ | |||||
| logger.info('test get negative sampled neighbors.\n') | |||||
| g = ds.GraphData(DATASET_FILE, 2) | g = ds.GraphData(DATASET_FILE, 2) | ||||
| nodes = g.get_all_nodes(1) | nodes = g.get_all_nodes(1) | ||||
| assert len(nodes) == 10 | assert len(nodes) == 10 | ||||
| @@ -98,6 +114,10 @@ def test_graphdata_getnegsampledneighbors(): | |||||
| def test_graphdata_graphinfo(): | def test_graphdata_graphinfo(): | ||||
| """ | |||||
| Test graph info | |||||
| """ | |||||
| logger.info('test graph info.\n') | |||||
| g = ds.GraphData(DATASET_FILE, 2) | g = ds.GraphData(DATASET_FILE, 2) | ||||
| graph_info = g.graph_info() | graph_info = g.graph_info() | ||||
| assert graph_info['node_type'] == [1, 2] | assert graph_info['node_type'] == [1, 2] | ||||
| @@ -155,6 +175,10 @@ class GNNGraphDataset(): | |||||
| def test_graphdata_generatordataset(): | def test_graphdata_generatordataset(): | ||||
| """ | |||||
| Test generator dataset | |||||
| """ | |||||
| logger.info('test generator dataset.\n') | |||||
| g = ds.GraphData(DATASET_FILE) | g = ds.GraphData(DATASET_FILE) | ||||
| batch_num = 2 | batch_num = 2 | ||||
| edge_num = g.graph_info()['edge_num'][0] | edge_num = g.graph_info()['edge_num'][0] | ||||
| @@ -173,7 +197,11 @@ def test_graphdata_generatordataset(): | |||||
| assert i == 40 | assert i == 40 | ||||
| def test_graphdata_randomwalk(): | |||||
| def test_graphdata_randomwalkdefault(): | |||||
| """ | |||||
| Test random walk defaults | |||||
| """ | |||||
| logger.info('test randomwalk with default parameters.\n') | |||||
| g = ds.GraphData(SOCIAL_DATA_FILE, 1) | g = ds.GraphData(SOCIAL_DATA_FILE, 1) | ||||
| nodes = g.get_all_nodes(1) | nodes = g.get_all_nodes(1) | ||||
| print(len(nodes)) | print(len(nodes)) | ||||
| @@ -184,18 +212,27 @@ def test_graphdata_randomwalk(): | |||||
| assert walks.shape == (33, 40) | assert walks.shape == (33, 40) | ||||
| def test_graphdata_randomwalk(): | |||||
| """ | |||||
| Test random walk | |||||
| """ | |||||
| logger.info('test random walk with given parameters.\n') | |||||
| g = ds.GraphData(SOCIAL_DATA_FILE, 1) | |||||
| nodes = g.get_all_nodes(1) | |||||
| print(len(nodes)) | |||||
| assert len(nodes) == 33 | |||||
| meta_path = [1 for _ in range(39)] | |||||
| walks = g.random_walk(nodes, meta_path, 2.0, 0.5, -1) | |||||
| assert walks.shape == (33, 40) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_graphdata_getfullneighbor() | test_graphdata_getfullneighbor() | ||||
| logger.info('test_graphdata_getfullneighbor Ended.\n') | |||||
| test_graphdata_getnodefeature_input_check() | test_graphdata_getnodefeature_input_check() | ||||
| logger.info('test_graphdata_getnodefeature_input_check Ended.\n') | |||||
| test_graphdata_getsampledneighbors() | test_graphdata_getsampledneighbors() | ||||
| logger.info('test_graphdata_getsampledneighbors Ended.\n') | |||||
| test_graphdata_getnegsampledneighbors() | test_graphdata_getnegsampledneighbors() | ||||
| logger.info('test_graphdata_getnegsampledneighbors Ended.\n') | |||||
| test_graphdata_graphinfo() | test_graphdata_graphinfo() | ||||
| logger.info('test_graphdata_graphinfo Ended.\n') | |||||
| test_graphdata_generatordataset() | test_graphdata_generatordataset() | ||||
| logger.info('test_graphdata_generatordataset Ended.\n') | |||||
| test_graphdata_randomwalkdefault() | |||||
| test_graphdata_randomwalk() | test_graphdata_randomwalk() | ||||
| logger.info('test_graphdata_randomwalk Ended.\n') | |||||