From: @lizhenglong1992 Reviewed-by: Signed-off-by:pull/14874/MERGE
| @@ -57,6 +57,12 @@ PYBIND_REGISTER( | |||
| THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_edges_from_nodes", | |||
| [](gnn::GraphData &g, std::vector<std::pair<gnn::NodeIdType, gnn::NodeIdType>> node_list) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetEdgesFromNodes(node_list, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_all_neighbors", | |||
| [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) { | |||
| std::shared_ptr<Tensor> out; | |||
| @@ -50,12 +50,13 @@ enum GnnOpName { | |||
| GET_ALL_NODES = 0; | |||
| GET_ALL_EDGES = 1; | |||
| GET_NODES_FROM_EDGES = 2; | |||
| GET_ALL_NEIGHBORS = 3; | |||
| GET_SAMPLED_NEIGHBORS = 4; | |||
| GET_NEG_SAMPLED_NEIGHBORS = 5; | |||
| RANDOM_WALK = 6; | |||
| GET_NODE_FEATURE = 7; | |||
| GET_EDGE_FEATURE = 8; | |||
| GET_EDGES_FROM_NODES = 3; | |||
| GET_ALL_NEIGHBORS = 4; | |||
| GET_SAMPLED_NEIGHBORS = 5; | |||
| GET_NEG_SAMPLED_NEIGHBORS = 6; | |||
| RANDOM_WALK = 7; | |||
| GET_NODE_FEATURE = 8; | |||
| GET_EDGE_FEATURE = 9; | |||
| } | |||
| message GnnRandomWalkPb { | |||
| @@ -64,6 +65,11 @@ message GnnRandomWalkPb { | |||
| int32 default_id = 3; | |||
| } | |||
| message IdPairPb { | |||
| int32 src_id = 1; | |||
| int32 dst_id = 2; | |||
| } | |||
| message GnnGraphDataRequestPb { | |||
| GnnOpName op_name = 1; | |||
| repeated int32 id = 2; // node id or edge id | |||
| @@ -72,6 +78,7 @@ message GnnGraphDataRequestPb { | |||
| TensorPb id_tensor = 5; // input ids ,node id or edge id | |||
| GnnRandomWalkPb random_walk = 6; | |||
| int32 strategy = 7; | |||
| repeated IdPairPb node_pair = 8; | |||
| } | |||
| message GnnGraphDataResponsePb { | |||
| @@ -62,6 +62,13 @@ class GraphData { | |||
| // @return Status The status code returned | |||
| virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0; | |||
| // Get the edge id from connected node pair | |||
| // @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of pair nodes | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The status code that indicate the result of function execution | |||
| virtual Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||
| std::shared_ptr<Tensor> *out) = 0; | |||
| // All neighbors of the acquisition node. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | |||
| @@ -120,6 +120,25 @@ Status GraphDataClient::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_li | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||
| std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(GET_EDGES_FROM_NODES); | |||
| for (const auto &pair_node_id : node_list) { | |||
| IdPairPb *proto_pair(request.add_node_pair()); | |||
| proto_pair->set_src_id(static_cast<google::protobuf::int32>(pair_node_id.first)); | |||
| proto_pair->set_dst_id(static_cast<google::protobuf::int32>(pair_node_id.second)); | |||
| } | |||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| @@ -72,6 +72,13 @@ class GraphDataClient : public GraphData { | |||
| // @return Status The status code returned | |||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | |||
| // Get the edge id from connected node pair | |||
| // @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of pair nodes | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The status code that indicate the result of function execution | |||
| Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| // All neighbors of the acquisition node. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | |||
| @@ -128,6 +128,30 @@ Status GraphDataImpl::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataImpl::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||
| std::shared_ptr<Tensor> *out) { | |||
| if (node_list.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Input node list is empty."); | |||
| } | |||
| std::vector<std::vector<EdgeIdType>> edge_list; | |||
| edge_list.reserve(node_list.size()); | |||
| for (const auto &node_id : node_list) { | |||
| std::shared_ptr<Node> src_node; | |||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_id.first, &src_node)); | |||
| EdgeIdType *edge_id = nullptr; | |||
| src_node->GetEdgeByAdjNodeId(node_id.second, &edge_id); | |||
| std::vector<EdgeIdType> connection_edge = {*edge_id}; | |||
| edge_list.emplace_back(std::move(connection_edge)); | |||
| } | |||
| RETURN_IF_NOT_OK(CreateTensorByVector<EdgeIdType>(edge_list, DataType(DataType::DE_INT32), out)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| @@ -66,6 +66,13 @@ class GraphDataImpl : public GraphData { | |||
| // @return Status The status code returned | |||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | |||
| // Get the edge id from connected node pair | |||
| // @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of pair nodes | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The status code that indicate the result of function execution | |||
| Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| // All neighbors of the acquisition node. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | |||
| @@ -17,6 +17,7 @@ | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/gnn/tensor_proto.h" | |||
| @@ -31,6 +32,7 @@ static std::unordered_map<uint32_t, pFunction> g_get_graph_data_func_ = { | |||
| {GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes}, | |||
| {GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges}, | |||
| {GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges}, | |||
| {GET_EDGES_FROM_NODES, &GraphDataServiceImpl::GetEdgesFromNodes}, | |||
| {GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors}, | |||
| {GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors}, | |||
| {GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors}, | |||
| @@ -189,6 +191,27 @@ Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *requ | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->node_pair_size() > 0, "The input node pair id list is empty."); | |||
| std::vector<std::pair<NodeIdType, NodeIdType>> node_list; | |||
| node_list.resize(request->node_pair().size()); | |||
| std::transform( | |||
| request->node_pair().begin(), request->node_pair().end(), node_list.begin(), [](const auto &node_pair_id) { | |||
| auto cur_pair = | |||
| std::make_pair(static_cast<NodeIdType>(node_pair_id.src_id()), static_cast<NodeIdType>(node_pair_id.dst_id())); | |||
| return cur_pair; | |||
| }); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetEdgesFromNodes(node_list, &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); | |||
| @@ -50,6 +50,7 @@ class GraphDataServiceImpl { | |||
| Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| @@ -61,10 +61,14 @@ Status GraphLoader::GetNodesAndEdges() { | |||
| std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p; | |||
| RETURN_IF_NOT_OK(edge_ptr->GetNode(&p)); | |||
| auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); | |||
| RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); | |||
| RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight())); | |||
| RETURN_IF_NOT_OK(src_itr->second->AddAdjacent(dst_itr->second, edge_ptr)); | |||
| e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ | |||
| graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); | |||
| dq.pop_front(); | |||
| @@ -131,6 +131,26 @@ Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightTyp | |||
| return Status::OK(); | |||
| } | |||
| Status LocalNode::AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) { | |||
| auto node_id = node->id(); | |||
| auto edge_id = edge->id(); | |||
| adjacent_nodes_.insert({node_id, edge_id}); | |||
| return Status::OK(); | |||
| } | |||
| Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) { | |||
| auto itr = adjacent_nodes_.find(adj_node_id); | |||
| if (itr != adjacent_nodes_.end()) { | |||
| (*out_edge_id) = &(itr->second); | |||
| } else { | |||
| (*out_edge_id) = new EdgeIdType(-1); | |||
| MS_LOG(WARNING) << "Number " << adj_node_id << " node is not adjacent to number " << this->id() << " node."; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status LocalNode::UpdateFeature(const std::shared_ptr<Feature> &feature) { | |||
| auto itr = features_.find(feature->type()); | |||
| if (itr != features_.end()) { | |||
| @@ -65,6 +65,18 @@ class LocalNode : public Node { | |||
| // @return Status The status code returned | |||
| Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override; | |||
| // Add adjacent node and relative edge for source node | |||
| // @param std::shared_ptr<Node> node - the node to be inserted into adjacent table | |||
| // @param std::shared_ptr<Edge> edge - the edge related to the adjacent node of source node | |||
| // @return Status - The status code that indicate the result of function execution | |||
| Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) override; | |||
| // Get relative connecting edge of adjacent node by node id | |||
| // @param NodeIdType - The id of adjacent node to be processed | |||
| // @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge | |||
| // @return Status - The status code that indicate the result of function execution | |||
| Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) override; | |||
| // Update feature of node | |||
| // @param std::shared_ptr<Feature> feature - | |||
| // @return Status The status code returned | |||
| @@ -81,6 +93,7 @@ class LocalNode : public Node { | |||
| std::mt19937 rnd_; | |||
| std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_; | |||
| std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_; | |||
| std::unordered_map<NodeIdType, EdgeIdType> adjacent_nodes_; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| @@ -29,9 +29,12 @@ namespace gnn { | |||
| using NodeType = int8_t; | |||
| using NodeIdType = int32_t; | |||
| using WeightType = float; | |||
| using EdgeIdType = int32_t; | |||
| constexpr NodeIdType kDefaultNodeId = -1; | |||
| class Edge; | |||
| class Node { | |||
| public: | |||
| // Constructor | |||
| @@ -78,6 +81,18 @@ class Node { | |||
| // @return Status The status code returned | |||
| virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0; | |||
| // Add adjacent node and relative edge for source node | |||
| // @param std::shared_ptr<Node> node - the node to be inserted into adjacent table | |||
| // @param std::shared_ptr<Edge> edge - the edge related to the adjacent node of source node | |||
| // @return Status - The status code that indicate the result of function execution | |||
| virtual Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) = 0; | |||
| // Get relative connecting edge of adjacent node by node id | |||
| // @param NodeIdType - The id of adjacent node to be processed | |||
| // @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge | |||
| // @return Status - The status code that indicate the result of function execution | |||
| virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) = 0; | |||
| // Update feature of node | |||
| // @param std::shared_ptr<Feature> feature - | |||
| // @return Status The status code returned | |||
| @@ -102,8 +102,7 @@ Status DvppDecodePngOp::Compute(const std::shared_ptr<Tensor> &input, std::share | |||
| unsigned char *ret_ptr = data.get(); | |||
| std::shared_ptr<DvppDataInfo> DecodeOut(process.Get_Decode_DeviceData()); | |||
| dsize_t dvpp_length = DecodeOut->dataSize; | |||
| // dsize_t decode_height = DecodeOut->height; | |||
| // dsize_t decode_width = DecodeOut->width; | |||
| const TensorShape dvpp_shape({dvpp_length, 1, 1}); | |||
| const DataType dvpp_data_type(DataType::DE_UINT8); | |||
| mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output); | |||
| @@ -391,6 +391,38 @@ def validate_dataset_param_value(param_list, param_dict, param_type): | |||
| type_check(param_dict.get(param_name), (param_type,), param_name) | |||
| def check_gnn_list_of_pair_or_ndarray(param, param_name): | |||
| """ | |||
| Check if the input parameter is a list of tuple or numpy.ndarray. | |||
| Args: | |||
| param (Union[list[tuple], nd.ndarray]): param. | |||
| param_name (str): param_name. | |||
| Returns: | |||
| Exception: TypeError if error. | |||
| """ | |||
| type_check(param, (list, np.ndarray), param_name) | |||
| if isinstance(param, list): | |||
| param_names = ["pair_{0}".format(i) for i in range(len(param))] | |||
| type_check_list(param, (tuple,), param_names) | |||
| for idx, pair in enumerate(param): | |||
| if not len(pair) == 2: | |||
| raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format( | |||
| param_names[idx], len(pair))) | |||
| column_names = ["element_{0}".format(i) for i in range(len(pair))] | |||
| type_check_list(pair, (int,), column_names) | |||
| elif isinstance(param, np.ndarray): | |||
| if param.ndim != 2: | |||
| raise ValueError("Input ndarray must be in dimension 2. Got {0}".format(param.ndim)) | |||
| if param.shape[1] != 2: | |||
| raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format( | |||
| param_name, param.shape[1])) | |||
| if not param.dtype == np.int32: | |||
| raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( | |||
| param_name, param.dtype)) | |||
| def check_gnn_list_or_ndarray(param, param_name): | |||
| """ | |||
| Check if the input parameter is list or numpy.ndarray. | |||
| @@ -26,9 +26,9 @@ from mindspore._c_dataengine import Tensor | |||
| from mindspore._c_dataengine import SamplingStrategy as Sampling | |||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | |||
| check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ | |||
| check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_get_edge_feature, \ | |||
| check_gnn_random_walk | |||
| check_gnn_get_nodes_from_edges, check_gnn_get_edges_from_nodes, check_gnn_get_all_neighbors, \ | |||
| check_gnn_get_sampled_neighbors, check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, \ | |||
| check_gnn_get_edge_feature, check_gnn_random_walk | |||
| class SamplingStrategy(IntEnum): | |||
| @@ -163,6 +163,27 @@ class GraphData: | |||
| raise Exception("This method is not supported when working mode is server.") | |||
| return self._graph_data.get_nodes_from_edges(edge_list).as_array() | |||
| @check_gnn_get_edges_from_nodes | |||
| def get_edges_from_nodes(self, node_list): | |||
| """ | |||
| Get edges from the nodes. | |||
| Args: | |||
| node_list (Union[list[tuple], numpy.ndarray]): The given list of pair nodes ID. | |||
| Returns: | |||
| numpy.ndarray, array of edgs ID. | |||
| Examples: | |||
| >>> edges = graph_dataset.get_edges_from_nodes([(1, 3), (5, 2)]) | |||
| Raises: | |||
| TypeError: If `edge_list` is not list or ndarray. | |||
| """ | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server.") | |||
| return self._graph_data.get_edges_from_nodes(node_list).as_array() | |||
| @check_gnn_get_all_neighbors | |||
| def get_all_neighbors(self, node_list, neighbor_type): | |||
| """ | |||
| @@ -25,8 +25,8 @@ import numpy as np | |||
| from mindspore._c_expression import typing | |||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ | |||
| INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ | |||
| validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ | |||
| check_columns, check_pos_int32, check_valid_str | |||
| validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \ | |||
| check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str | |||
| from . import datasets | |||
| from . import samplers | |||
| @@ -1090,6 +1090,19 @@ def check_gnn_get_nodes_from_edges(method): | |||
| return new_method | |||
| def check_gnn_get_edges_from_nodes(method): | |||
| """A wrapper that wraps a parameter checker around the GNN `get_edges_from_nodes` function.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [node_list], _ = parse_user_args(method, *args, **kwargs) | |||
| check_gnn_list_of_pair_or_ndarray(node_list, "node_list") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_gnn_get_all_neighbors(method): | |||
| """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function.""" | |||
| @@ -95,6 +95,21 @@ class MindDataTestGNNGraph : public UT::Common { | |||
| } | |||
| }; | |||
| TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) { | |||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||
| GraphDataImpl graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::vector<std::pair<NodeIdType, NodeIdType>> src_dst_list = {{101, 201}, {103, 207}, {108, 208}, | |||
| {110, 201}, {204, 105}, {208, 108}}; | |||
| std::shared_ptr<Tensor> edges; | |||
| s = graph.GetEdgesFromNodes(src_dst_list, &edges); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(edges->ToString() == "Tensor (shape: <6>, Type: int32)\n[1,9,17,19,31,37]"); | |||
| } | |||
| TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||
| GraphDataImpl graph(path, 1); | |||
| @@ -241,6 +241,18 @@ def test_graphdata_getedgefeature(): | |||
| assert features[1].shape == (40,) | |||
| def test_graphdata_getedgesfromnodes(): | |||
| """ | |||
| Test get edges from nodes | |||
| """ | |||
| logger.info('test get_edges_from_nodes\n') | |||
| g = ds.GraphData(DATASET_FILE) | |||
| nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)] | |||
| edges = g.get_edges_from_nodes(nodes_pair_list) | |||
| assert edges.tolist() == [1, 9, 31, 17, 20, 40] | |||
| if __name__ == '__main__': | |||
| test_graphdata_getfullneighbor() | |||
| test_graphdata_getnodefeature_input_check() | |||
| @@ -251,3 +263,4 @@ if __name__ == '__main__': | |||
| test_graphdata_randomwalkdefault() | |||
| test_graphdata_randomwalk() | |||
| test_graphdata_getedgefeature() | |||
| test_graphdata_getedgesfromnodes() | |||
| @@ -112,6 +112,10 @@ def test_graphdata_distributed(): | |||
| assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, | |||
| 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0] | |||
| nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)] | |||
| edges = g.get_edges_from_nodes(nodes_pair_list) | |||
| assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37] | |||
| batch_num = 2 | |||
| edge_num = g.graph_info()['edge_num'][0] | |||
| out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"] | |||