add fix to random resize decode crop test case fix pylint issuestags/v0.6.0-beta
| @@ -439,6 +439,18 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co | |||
| ", step_away_param: " + std::to_string(step_away_param); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| if (default_node < -1) { | |||
| std::string err_msg = "Failed, default_node required to be greater or equal to -1."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| if (num_walks <= 0) { | |||
| std::string err_msg = "Failed, num_walks parameter required to be greater than 0"; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| if (num_workers <= 0) { | |||
| std::string err_msg = "Failed, num_workers parameter required to be greater than 0"; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| step_home_param_ = step_home_param; | |||
| step_away_param_ = step_away_param; | |||
| default_node_ = default_node; | |||
| @@ -181,7 +181,7 @@ class Graph { | |||
| float step_away_param_; // Inout hyper parameter. Default is 1.0 | |||
| NodeIdType default_node_; | |||
| int32_t num_walks_; // Number of walks per source. Default is 10 | |||
| int32_t num_walks_; // Number of walks per source. Default is 1 | |||
| int32_t num_workers_; // The number of worker threads. Default is 1 | |||
| }; | |||
| @@ -232,9 +232,10 @@ class GraphData: | |||
| Args: | |||
| target_nodes (list[int]): Start node list in random walk | |||
| meta_path (list[int]): node type for each walk step | |||
| step_home_param (float): return hyper parameter in node2vec algorithm | |||
| step_away_param (float): inout hyper parameter in node2vec algorithm | |||
| default_node (int): default node if no more neighbors found | |||
| step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0). | |||
| step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0). | |||
| default_node (int, optional): default node if no more neighbors found (Default = -1). | |||
| A default value of -1 indicates that no node is given. | |||
| Returns: | |||
| numpy.ndarray: array of nodes. | |||
| @@ -1260,6 +1260,10 @@ def check_gnn_random_walk(method): | |||
| # check meta_path; required argument | |||
| check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path') | |||
| check_type(param_dict.get("step_home_param"), 'step_home_param', float) | |||
| check_type(param_dict.get("step_away_param"), 'step_away_param', float) | |||
| check_type(param_dict.get("default_node"), 'default_node', int) | |||
| return method(*args, **kwargs) | |||
| return new_method | |||
| @@ -247,4 +247,30 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) { | |||
| s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); | |||
| } | |||
| } | |||
| TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) { | |||
| std::string path = "data/mindrecord/testGraphData/sns"; | |||
| Graph graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| MetaInfo meta_info; | |||
| s = graph.GetMetaInfo(&meta_info); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::shared_ptr<Tensor> nodes; | |||
| s = graph.GetAllNodes(meta_info.node_type[0], &nodes); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::vector<NodeIdType> node_list; | |||
| for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | |||
| node_list.push_back(*itr); | |||
| } | |||
| print_int_vec(node_list, "node list "); | |||
| std::vector<NodeType> meta_path(59, 1); | |||
| std::shared_ptr<Tensor> walk_path; | |||
| s = graph.RandomWalk(node_list, meta_path, 1.0, 1.0, -1, &walk_path); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); | |||
| } | |||
| @@ -54,7 +54,7 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp2) { | |||
| auto decode_and_crop = static_cast<RandomCropAndResizeOp>(crop_and_decode_copy); | |||
| EXPECT_TRUE(crop_and_decode.OneToOne()); | |||
| GlobalContext::config_manager()->set_seed(42); | |||
| for (int k = 0; k < 100; k++) { | |||
| for (int k = 0; k < 10; k++) { | |||
| (void)crop_and_decode.Compute(raw_input_tensor_, &crop_and_decode_output); | |||
| (void)decode_and_crop.Compute(input_tensor_, &decode_and_crop_output); | |||
| cv::Mat output1 = CVTensor::AsCVTensor(crop_and_decode_output)->mat().clone(); | |||
| @@ -104,10 +104,10 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp1) { | |||
| int mse_sum, m1, m2, count; | |||
| double mse; | |||
| for (int k = 0; k < 100; ++k) { | |||
| for (int k = 0; k < 10; ++k) { | |||
| mse_sum = 0; | |||
| count = 0; | |||
| for (auto i = 0; i < 100; i++) { | |||
| for (auto i = 0; i < 10; i++) { | |||
| scale = rd_scale(rd); | |||
| aspect = rd_aspect(rd); | |||
| crop_width = std::round(std::sqrt(h * w * scale / aspect)); | |||
| @@ -23,6 +23,10 @@ SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" | |||
| def test_graphdata_getfullneighbor(): | |||
| """ | |||
| Test get all neighbors | |||
| """ | |||
| logger.info('test get all neighbors.\n') | |||
| g = ds.GraphData(DATASET_FILE, 2) | |||
| nodes = g.get_all_nodes(1) | |||
| assert len(nodes) == 10 | |||
| @@ -33,6 +37,10 @@ def test_graphdata_getfullneighbor(): | |||
| def test_graphdata_getnodefeature_input_check(): | |||
| """ | |||
| Test get node feature input check | |||
| """ | |||
| logger.info('test getnodefeature input check.\n') | |||
| g = ds.GraphData(DATASET_FILE) | |||
| with pytest.raises(TypeError): | |||
| input_list = [1, [1, 1]] | |||
| @@ -80,6 +88,10 @@ def test_graphdata_getnodefeature_input_check(): | |||
| def test_graphdata_getsampledneighbors(): | |||
| """ | |||
| Test sampled neighbors | |||
| """ | |||
| logger.info('test get sampled neighbors.\n') | |||
| g = ds.GraphData(DATASET_FILE, 1) | |||
| edges = g.get_all_edges(0) | |||
| nodes = g.get_nodes_from_edges(edges) | |||
| @@ -90,6 +102,10 @@ def test_graphdata_getsampledneighbors(): | |||
| def test_graphdata_getnegsampledneighbors(): | |||
| """ | |||
| Test neg sampled neighbors | |||
| """ | |||
| logger.info('test get negative sampled neighbors.\n') | |||
| g = ds.GraphData(DATASET_FILE, 2) | |||
| nodes = g.get_all_nodes(1) | |||
| assert len(nodes) == 10 | |||
| @@ -98,6 +114,10 @@ def test_graphdata_getnegsampledneighbors(): | |||
| def test_graphdata_graphinfo(): | |||
| """ | |||
| Test graph info | |||
| """ | |||
| logger.info('test graph info.\n') | |||
| g = ds.GraphData(DATASET_FILE, 2) | |||
| graph_info = g.graph_info() | |||
| assert graph_info['node_type'] == [1, 2] | |||
| @@ -155,6 +175,10 @@ class GNNGraphDataset(): | |||
| def test_graphdata_generatordataset(): | |||
| """ | |||
| Test generator dataset | |||
| """ | |||
| logger.info('test generator dataset.\n') | |||
| g = ds.GraphData(DATASET_FILE) | |||
| batch_num = 2 | |||
| edge_num = g.graph_info()['edge_num'][0] | |||
| @@ -173,7 +197,11 @@ def test_graphdata_generatordataset(): | |||
| assert i == 40 | |||
| def test_graphdata_randomwalk(): | |||
| def test_graphdata_randomwalkdefault(): | |||
| """ | |||
| Test random walk defaults | |||
| """ | |||
| logger.info('test randomwalk with default parameters.\n') | |||
| g = ds.GraphData(SOCIAL_DATA_FILE, 1) | |||
| nodes = g.get_all_nodes(1) | |||
| print(len(nodes)) | |||
| @@ -184,18 +212,27 @@ def test_graphdata_randomwalk(): | |||
| assert walks.shape == (33, 40) | |||
| def test_graphdata_randomwalk(): | |||
| """ | |||
| Test random walk | |||
| """ | |||
| logger.info('test random walk with given parameters.\n') | |||
| g = ds.GraphData(SOCIAL_DATA_FILE, 1) | |||
| nodes = g.get_all_nodes(1) | |||
| print(len(nodes)) | |||
| assert len(nodes) == 33 | |||
| meta_path = [1 for _ in range(39)] | |||
| walks = g.random_walk(nodes, meta_path, 2.0, 0.5, -1) | |||
| assert walks.shape == (33, 40) | |||
| if __name__ == '__main__': | |||
| test_graphdata_getfullneighbor() | |||
| logger.info('test_graphdata_getfullneighbor Ended.\n') | |||
| test_graphdata_getnodefeature_input_check() | |||
| logger.info('test_graphdata_getnodefeature_input_check Ended.\n') | |||
| test_graphdata_getsampledneighbors() | |||
| logger.info('test_graphdata_getsampledneighbors Ended.\n') | |||
| test_graphdata_getnegsampledneighbors() | |||
| logger.info('test_graphdata_getnegsampledneighbors Ended.\n') | |||
| test_graphdata_graphinfo() | |||
| logger.info('test_graphdata_graphinfo Ended.\n') | |||
| test_graphdata_generatordataset() | |||
| logger.info('test_graphdata_generatordataset Ended.\n') | |||
| test_graphdata_randomwalkdefault() | |||
| test_graphdata_randomwalk() | |||
| logger.info('test_graphdata_randomwalk Ended.\n') | |||