| @@ -149,14 +149,37 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType | |||
| return Status::OK(); | |||
| } | |||
| // Validate that a requested sample count lies in the range [1, total number of nodes]. | |||
| // The total is the sum of the sizes of every mapped value in node_type_map_ | |||
| // (presumably one node-id collection per node type -- confirm against the member declaration). | |||
| // @param samples_num - the requested number of samples | |||
| // @return Status::OK() when samples_num is in range; otherwise the error raised by | |||
| //     RETURN_STATUS_UNEXPECTED, whose message reports the valid range and the received value | |||
| Status Graph::CheckSamplesNum(NodeIdType samples_num) { | |||
| // std::accumulate deduces its accumulator type from the initial value, so seed it with a NodeIdType | |||
| // zero; a bare `0` would make the running sum an int regardless of what NodeIdType is. | |||
| const NodeIdType all_nodes_number = std::accumulate( | |||
| node_type_map_.begin(), node_type_map_.end(), static_cast<NodeIdType>(0), | |||
| [](NodeIdType total, const auto &entry) -> NodeIdType { | |||
| // size() is unsigned; convert explicitly so the sum is not computed in unsigned arithmetic | |||
| // and then implicitly narrowed back. | |||
| return total + static_cast<NodeIdType>(entry.second.size()); | |||
| }); | |||
| if ((samples_num < 1) || (samples_num > all_nodes_number)) { | |||
| std::string err_msg = "Wrong samples number, should be between 1 and " + std::to_string(all_nodes_number) + | |||
| ", got " + std::to_string(samples_num); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), | |||
| "The sizes of neighbor_nums and neighbor_types are inconsistent."); | |||
| for (const auto &num : neighbor_nums) { | |||
| RETURN_IF_NOT_OK(CheckSamplesNum(num)); | |||
| } | |||
| for (const auto &type : neighbor_types) { | |||
| if (node_type_map_.find(type) == node_type_map_.end()) { | |||
| std::string err_msg = "Invalid neighbor type:" + std::to_string(type); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| } | |||
| std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size()); | |||
| for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | |||
| std::shared_ptr<Node> input_node; | |||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &input_node)); | |||
| neighbors_vec[node_idx].emplace_back(node_list[node_idx]); | |||
| std::vector<NodeIdType> input_list = {node_list[node_idx]}; | |||
| for (size_t i = 0; i < neighbor_nums.size(); ++i) { | |||
| @@ -204,6 +227,12 @@ Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::uno | |||
| Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); | |||
| if (node_type_map_.find(neg_neighbor_type) == node_type_map_.end()) { | |||
| std::string err_msg = "Invalid neighbor type:" + std::to_string(neg_neighbor_type); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| std::vector<std::vector<NodeIdType>> neighbors_vec; | |||
| neighbors_vec.resize(node_list.size()); | |||
| for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | |||
| @@ -226,6 +226,8 @@ class Graph { | |||
| Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data, | |||
| int32_t samples_num, std::vector<NodeIdType> *out_samples); | |||
| Status CheckSamplesNum(NodeIdType samples_num); | |||
| std::string dataset_file_; | |||
| int32_t num_workers_; // The number of worker threads | |||
| std::mt19937 rnd_; | |||
| @@ -1110,10 +1110,10 @@ def check_gnn_list_or_ndarray(param, param_name): | |||
| for m in param: | |||
| if not isinstance(m, int): | |||
| raise TypeError( | |||
| "Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m))) | |||
| "Each member in {0} should be of type int. Got {1}.".format(param_name, type(m))) | |||
| elif isinstance(param, np.ndarray): | |||
| if not param.dtype == np.int32: | |||
| raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format( | |||
| raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( | |||
| param_name, param.dtype)) | |||
| else: | |||
| raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( | |||
| @@ -1196,15 +1196,15 @@ def check_gnn_get_sampled_neighbors(method): | |||
| # check neighbor_nums; required argument | |||
| neighbor_nums = param_dict.get("neighbor_nums") | |||
| check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums') | |||
| if len(neighbor_nums) > 6: | |||
| raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( | |||
| if not neighbor_nums or len(neighbor_nums) > 6: | |||
| raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format( | |||
| 'neighbor_nums', len(neighbor_nums))) | |||
| # check neighbor_types; required argument | |||
| neighbor_types = param_dict.get("neighbor_types") | |||
| check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types') | |||
| if len(neighbor_nums) > 6: | |||
| raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( | |||
| if not neighbor_types or len(neighbor_types) > 6: | |||
| raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format( | |||
| 'neighbor_types', len(neighbor_types))) | |||
| if len(neighbor_nums) != len(neighbor_types): | |||
| @@ -1256,7 +1256,7 @@ def check_gnn_random_walk(method): | |||
| return new_method | |||
| def check_aligned_list(param, param_name, membor_type): | |||
| def check_aligned_list(param, param_name, member_type): | |||
| """Check whether the structure of each member of the list is the same.""" | |||
| if not isinstance(param, list): | |||
| @@ -1264,27 +1264,27 @@ def check_aligned_list(param, param_name, membor_type): | |||
| if not param: | |||
| raise TypeError( | |||
| "Parameter {0} or its members are empty".format(param_name)) | |||
| membor_have_list = None | |||
| member_have_list = None | |||
| list_len = None | |||
| for membor in param: | |||
| if isinstance(membor, list): | |||
| check_aligned_list(membor, param_name, membor_type) | |||
| if membor_have_list not in (None, True): | |||
| for member in param: | |||
| if isinstance(member, list): | |||
| check_aligned_list(member, param_name, member_type) | |||
| if member_have_list not in (None, True): | |||
| raise TypeError("The type of each member of the parameter {0} is inconsistent".format( | |||
| param_name)) | |||
| if list_len is not None and len(membor) != list_len: | |||
| if list_len is not None and len(member) != list_len: | |||
| raise TypeError("The size of each member of parameter {0} is inconsistent".format( | |||
| param_name)) | |||
| membor_have_list = True | |||
| list_len = len(membor) | |||
| member_have_list = True | |||
| list_len = len(member) | |||
| else: | |||
| if not isinstance(membor, membor_type): | |||
| raise TypeError("Each membor in {0} should be of type int. Got {1}.".format( | |||
| param_name, type(membor))) | |||
| if membor_have_list not in (None, False): | |||
| if not isinstance(member, member_type): | |||
| raise TypeError("Each member in {0} should be of type int. Got {1}.".format( | |||
| param_name, type(member))) | |||
| if member_have_list not in (None, False): | |||
| raise TypeError("The type of each member of the parameter {0} is inconsistent".format( | |||
| param_name)) | |||
| membor_have_list = False | |||
| member_have_list = False | |||
| def check_gnn_get_node_feature(method): | |||
| @@ -1300,7 +1300,7 @@ def check_gnn_get_node_feature(method): | |||
| check_aligned_list(node_list, 'node_list', int) | |||
| elif isinstance(node_list, np.ndarray): | |||
| if not node_list.dtype == np.int32: | |||
| raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format( | |||
| raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( | |||
| node_list, node_list.dtype)) | |||
| else: | |||
| raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( | |||
| @@ -158,6 +158,18 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | |||
| s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != | |||
| @@ -198,9 +210,17 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { | |||
| s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); | |||
| neg_neighbors.reset(); | |||
| s = graph.GetNegSampledNeighbors({-1, 1}, 3, meta_info.node_type[1], &neg_neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); | |||
| neg_neighbors.reset(); | |||
| s = graph.GetNegSampledNeighbors(node_list, 50, meta_info.node_type[1], &neg_neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); | |||
| neg_neighbors.reset(); | |||
| s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos); | |||
| EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); | |||
| } | |||
| TEST_F(MindDataTestGNNGraph, TestRandomWalk) { | |||