diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc index 5e8c5e3a1c..46e8417557 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc @@ -57,6 +57,12 @@ PYBIND_REGISTER( THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); return out; }) + .def("get_edges_from_nodes", + [](gnn::GraphData &g, std::vector> node_list) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetEdgesFromNodes(node_list, &out)); + return out; + }) .def("get_all_neighbors", [](gnn::GraphData &g, std::vector node_list, gnn::NodeType neighbor_type) { std::shared_ptr out; diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto index f95a823eb7..0c6a92d1c2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto @@ -50,12 +50,13 @@ enum GnnOpName { GET_ALL_NODES = 0; GET_ALL_EDGES = 1; GET_NODES_FROM_EDGES = 2; - GET_ALL_NEIGHBORS = 3; - GET_SAMPLED_NEIGHBORS = 4; - GET_NEG_SAMPLED_NEIGHBORS = 5; - RANDOM_WALK = 6; - GET_NODE_FEATURE = 7; - GET_EDGE_FEATURE = 8; + GET_EDGES_FROM_NODES = 3; + GET_ALL_NEIGHBORS = 4; + GET_SAMPLED_NEIGHBORS = 5; + GET_NEG_SAMPLED_NEIGHBORS = 6; + RANDOM_WALK = 7; + GET_NODE_FEATURE = 8; + GET_EDGE_FEATURE = 9; } message GnnRandomWalkPb { @@ -64,6 +65,11 @@ message GnnRandomWalkPb { int32 default_id = 3; } +message IdPairPb { + int32 src_id = 1; + int32 dst_id = 2; +} + message GnnGraphDataRequestPb { GnnOpName op_name = 1; repeated int32 id = 2; // node id or edge id @@ -72,6 +78,7 @@ message GnnGraphDataRequestPb { TensorPb id_tensor = 5; // input ids ,node id or edge id GnnRandomWalkPb random_walk = 6; int32 strategy = 7; + repeated IdPairPb node_pair = 8; } message GnnGraphDataResponsePb { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h index c50bf194dd..965b45a30b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h @@ -62,6 +62,13 @@ class GraphData { // @return Status The status code returned virtual Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) = 0; + // Get the edge id from connected node pair + // @param std::vector> node_list - List of pair nodes + // @param std::shared_ptr *out - Returned edge ids + // @return Status - The status code that indicate the result of function execution + virtual Status GetEdgesFromNodes(const std::vector> &node_list, + std::shared_ptr *out) = 0; + // All neighbors of the acquisition node. // @param std::vector node_list - List of nodes // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc index a9f618ccbe..5ca02757a5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc @@ -120,6 +120,25 @@ Status GraphDataClient::GetNodesFromEdges(const std::vector &edge_li return Status::OK(); } +Status GraphDataClient::GetEdgesFromNodes(const std::vector> &node_list, + std::shared_ptr *out) { +#if !defined(_WIN32) && !defined(_WIN64) + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + + request.set_op_name(GET_EDGES_FROM_NODES); + + for (const auto &pair_node_id : node_list) { + IdPairPb *proto_pair(request.add_node_pair()); + proto_pair->set_src_id(static_cast(pair_node_id.first)); + proto_pair->set_dst_id(static_cast(pair_node_id.second)); + } + + RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); +#endif + return Status::OK(); +} + Status GraphDataClient::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, std::shared_ptr *out) { #if !defined(_WIN32) && !defined(_WIN64) diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h index 0e8d08f11b..7d76d7fec7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h @@ -72,6 +72,13 @@ class GraphDataClient : public GraphData { // @return Status The status code returned Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) override; + // Get the edge id from connected node pair + // @param std::vector> node_list - List of pair nodes + // @param std::shared_ptr *out - Returned edge ids + // @return Status - The status code that indicate the result of function execution + Status GetEdgesFromNodes(const std::vector> &node_list, + std::shared_ptr *out) override; + // All neighbors of the acquisition node. // @param std::vector node_list - List of nodes // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc index 70a090bb97..1264c68b6a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc @@ -128,6 +128,30 @@ Status GraphDataImpl::GetNodesFromEdges(const std::vector &edge_list return Status::OK(); } +Status GraphDataImpl::GetEdgesFromNodes(const std::vector> &node_list, + std::shared_ptr *out) { + if (node_list.empty()) { + RETURN_STATUS_UNEXPECTED("Input node list is empty."); + } + + std::vector> edge_list; + edge_list.reserve(node_list.size()); + + for (const auto &node_id : node_list) { + std::shared_ptr src_node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_id.first, &src_node)); + + EdgeIdType *edge_id = nullptr; + src_node->GetEdgeByAdjNodeId(node_id.second, &edge_id); + + std::vector connection_edge = {*edge_id}; + edge_list.emplace_back(std::move(connection_edge)); + } + + RETURN_IF_NOT_OK(CreateTensorByVector(edge_list, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} + Status GraphDataImpl::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, std::shared_ptr *out) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h index b5db50768b..2437a52a07 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h @@ -66,6 +66,13 @@ class GraphDataImpl : public GraphData { // @return Status The status code returned Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) override; + // Get the edge id from connected node pair + // @param std::vector> node_list - List of pair nodes + // @param std::shared_ptr *out - Returned edge ids + // @return Status - The status code that indicate the result of function execution + Status GetEdgesFromNodes(const std::vector> &node_list, + std::shared_ptr *out) override; + // All neighbors of the acquisition node. // @param std::vector node_list - List of nodes // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc index 04b930bb55..df79e37e50 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include "minddata/dataset/engine/gnn/tensor_proto.h" @@ -31,6 +32,7 @@ static std::unordered_map g_get_graph_data_func_ = { {GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes}, {GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges}, {GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges}, + {GET_EDGES_FROM_NODES, &GraphDataServiceImpl::GetEdgesFromNodes}, {GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors}, {GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors}, {GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors}, @@ -189,6 +191,27 @@ Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *requ return Status::OK(); } +Status GraphDataServiceImpl::GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { + CHECK_FAIL_RETURN_UNEXPECTED(request->node_pair_size() > 0, "The input node pair id list is empty."); + + std::vector> node_list; + node_list.resize(request->node_pair().size()); + + std::transform( + request->node_pair().begin(), request->node_pair().end(), node_list.begin(), [](const auto &node_pair_id) { + auto cur_pair = + std::make_pair(static_cast(node_pair_id.src_id()), static_cast(node_pair_id.dst_id())); + return cur_pair; + }); + + std::shared_ptr tensor; + RETURN_IF_NOT_OK(graph_data_impl_->GetEdgesFromNodes(node_list, &tensor)); + + TensorPb *result = response->add_result_data(); + RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + return Status::OK(); +} + Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h index 74996ccae4..3ed6fd386e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h @@ -50,6 +50,7 @@ class GraphDataServiceImpl { Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + Status GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc index 16dcfa4d3a..b8d1e1363b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc @@ -61,10 +61,14 @@ Status GraphLoader::GetNodesAndEdges() { std::pair, std::shared_ptr> p; RETURN_IF_NOT_OK(edge_ptr->GetNode(&p)); auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id()); + CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); + RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight())); + RETURN_IF_NOT_OK(src_itr->second->AddAdjacent(dst_itr->second, edge_ptr)); + e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); dq.pop_front(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc index bd7114f571..db0bec3394 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc @@ -131,6 +131,26 @@ Status LocalNode::AddNeighbor(const std::shared_ptr &node, const WeightTyp return Status::OK(); } +Status LocalNode::AddAdjacent(const std::shared_ptr &node, const std::shared_ptr &edge) { + auto node_id = node->id(); + auto edge_id = edge->id(); + adjacent_nodes_.insert({node_id, edge_id}); + return Status::OK(); +} + +Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) { + auto itr = adjacent_nodes_.find(adj_node_id); + + if (itr != adjacent_nodes_.end()) { + (*out_edge_id) = &(itr->second); + } else { + (*out_edge_id) = new EdgeIdType(-1); + MS_LOG(WARNING) << "Number " << adj_node_id << " node is not adjacent to number " << this->id() << " node."; + } + + return Status::OK(); +} + Status LocalNode::UpdateFeature(const std::shared_ptr &feature) { auto itr = features_.find(feature->type()); if (itr != features_.end()) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h index 7ec030556a..c424848741 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h @@ -65,6 +65,18 @@ class LocalNode : public Node { // @return Status The status code returned Status AddNeighbor(const std::shared_ptr &node, const WeightType &) override; + // Add adjacent node and relative edge for source node + // @param std::shared_ptr node - the node to be inserted into adjacent table + // @param std::shared_ptr edge - the edge related to the adjacent node of source node + // @return Status - The status code that indicate the result of function execution + Status AddAdjacent(const std::shared_ptr &node, const std::shared_ptr &edge) override; + + // Get relative connecting edge of adjacent node by node id + // @param NodeIdType - The id of adjacent node to be processed + // @param std::shared_ptr - The id of relative connecting edge + // @return Status - The status code that indicate the result of function execution + Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) override; + // Update feature of node // @param std::shared_ptr feature - // @return Status The status code returned @@ -81,6 +93,7 @@ class LocalNode : public Node { std::mt19937 rnd_; std::unordered_map> features_; std::unordered_map>, std::vector>> neighbor_nodes_; + std::unordered_map adjacent_nodes_; }; } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h index 3382df5c24..f5a33c4948 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h @@ -29,9 +29,12 @@ namespace gnn { using NodeType = int8_t; using NodeIdType = int32_t; using WeightType = float; +using EdgeIdType = int32_t; constexpr NodeIdType kDefaultNodeId = -1; +class Edge; + class Node { public: // Constructor @@ -78,6 +81,18 @@ class Node { // @return Status The status code returned virtual Status AddNeighbor(const std::shared_ptr &node, const WeightType &weight) = 0; + // Add adjacent node and relative edge for source node + // @param std::shared_ptr node - the node to be inserted into adjacent table + // @param std::shared_ptr edge - the edge related to the adjacent node of source node + // @return Status - The status code that indicate the result of function execution + virtual Status AddAdjacent(const std::shared_ptr &node, const std::shared_ptr &edge) = 0; + + // Get relative connecting edge of adjacent node by node id + // @param NodeIdType - The id of adjacent node to be processed + // @param std::shared_ptr - The id of relative connecting edge + // @return Status - The status code that indicate the result of function execution + virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) = 0; + // Update feature of node // @param std::shared_ptr feature - // @return Status The status code returned diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_png_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_png_op.cc index 630e0f3fe4..b3a6cbda30 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_png_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_png_op.cc @@ -102,8 +102,7 @@ Status DvppDecodePngOp::Compute(const std::shared_ptr &input, std::share unsigned char *ret_ptr = data.get(); std::shared_ptr DecodeOut(process.Get_Decode_DeviceData()); dsize_t dvpp_length = DecodeOut->dataSize; - // dsize_t decode_height = DecodeOut->height; - // dsize_t decode_width = DecodeOut->width; + const TensorShape dvpp_shape({dvpp_length, 1, 1}); const DataType dvpp_data_type(DataType::DE_UINT8); mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output); diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py index 80b4cb20b3..33632c6616 100644 --- a/mindspore/dataset/core/validator_helpers.py +++ b/mindspore/dataset/core/validator_helpers.py @@ -391,6 +391,38 @@ def validate_dataset_param_value(param_list, param_dict, param_type): type_check(param_dict.get(param_name), (param_type,), param_name) +def check_gnn_list_of_pair_or_ndarray(param, param_name): + """ + Check if the input parameter is a list of tuple or numpy.ndarray. + + Args: + param (Union[list[tuple], nd.ndarray]): param. + param_name (str): param_name. + + Returns: + Exception: TypeError if error. + """ + type_check(param, (list, np.ndarray), param_name) + if isinstance(param, list): + param_names = ["pair_{0}".format(i) for i in range(len(param))] + type_check_list(param, (tuple,), param_names) + for idx, pair in enumerate(param): + if not len(pair) == 2: + raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format( + param_names[idx], len(pair))) + column_names = ["element_{0}".format(i) for i in range(len(pair))] + type_check_list(pair, (int,), column_names) + elif isinstance(param, np.ndarray): + if param.ndim != 2: + raise ValueError("Input ndarray must be in dimension 2. Got {0}".format(param.ndim)) + if param.shape[1] != 2: + raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format( + param_name, param.shape[1])) + if not param.dtype == np.int32: + raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( + param_name, param.dtype)) + + def check_gnn_list_or_ndarray(param, param_name): """ Check if the input parameter is list or numpy.ndarray. diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index f0ff2dba4a..8f645a9eda 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -26,9 +26,9 @@ from mindspore._c_dataengine import Tensor from mindspore._c_dataengine import SamplingStrategy as Sampling from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ - check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ - check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_get_edge_feature, \ - check_gnn_random_walk + check_gnn_get_nodes_from_edges, check_gnn_get_edges_from_nodes, check_gnn_get_all_neighbors, \ + check_gnn_get_sampled_neighbors, check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, \ + check_gnn_get_edge_feature, check_gnn_random_walk class SamplingStrategy(IntEnum): @@ -163,6 +163,27 @@ class GraphData: raise Exception("This method is not supported when working mode is server.") return self._graph_data.get_nodes_from_edges(edge_list).as_array() + @check_gnn_get_edges_from_nodes + def get_edges_from_nodes(self, node_list): + """ + Get edges from the nodes. + + Args: + node_list (Union[list[tuple], numpy.ndarray]): The given list of pair nodes ID. + + Returns: + numpy.ndarray, array of edgs ID. + + Examples: + >>> edges = graph_dataset.get_edges_from_nodes([(1, 3), (5, 2)]) + + Raises: + TypeError: If `edge_list` is not list or ndarray. + """ + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server.") + return self._graph_data.get_edges_from_nodes(node_list).as_array() + @check_gnn_get_all_neighbors def get_all_neighbors(self, node_list, neighbor_type): """ diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index a83e5b77bf..f06adbebcd 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -25,8 +25,8 @@ import numpy as np from mindspore._c_expression import typing from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ - validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ - check_columns, check_pos_int32, check_valid_str + validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \ + check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str from . import datasets from . import samplers @@ -1090,6 +1090,19 @@ def check_gnn_get_nodes_from_edges(method): return new_method +def check_gnn_get_edges_from_nodes(method): + """A wrapper that wraps a parameter checker around the GNN `get_edges_from_nodes` function.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [node_list], _ = parse_user_args(method, *args, **kwargs) + check_gnn_list_of_pair_or_ndarray(node_list, "node_list") + + return method(self, *args, **kwargs) + + return new_method + + def check_gnn_get_all_neighbors(method): """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function.""" diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index 3a80613b9c..81990b972b 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -95,6 +95,21 @@ class MindDataTestGNNGraph : public UT::Common { } }; +TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) { + std::string path = "data/mindrecord/testGraphData/testdata"; + GraphDataImpl graph(path, 1); + Status s = graph.Init(); + EXPECT_TRUE(s.IsOk()); + + std::vector> src_dst_list = {{101, 201}, {103, 207}, {108, 208}, + {110, 201}, {204, 105}, {208, 108}}; + std::shared_ptr edges; + s = graph.GetEdgesFromNodes(src_dst_list, &edges); + + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(edges->ToString() == "Tensor (shape: <6>, Type: int32)\n[1,9,17,19,31,37]"); +} + TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { std::string path = "data/mindrecord/testGraphData/testdata"; GraphDataImpl graph(path, 1); diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 83f84dc7b2..3e55d47c92 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -241,6 +241,18 @@ def test_graphdata_getedgefeature(): assert features[1].shape == (40,) +def test_graphdata_getedgesfromnodes(): + """ + Test get edges from nodes + """ + logger.info('test get_edges_from_nodes\n') + g = ds.GraphData(DATASET_FILE) + + nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)] + edges = g.get_edges_from_nodes(nodes_pair_list) + assert edges.tolist() == [1, 9, 31, 17, 20, 40] + + if __name__ == '__main__': test_graphdata_getfullneighbor() test_graphdata_getnodefeature_input_check() @@ -251,3 +263,4 @@ if __name__ == '__main__': test_graphdata_randomwalkdefault() test_graphdata_randomwalk() test_graphdata_getedgefeature() + test_graphdata_getedgesfromnodes() diff --git a/tests/ut/python/dataset/test_graphdata_distributed.py b/tests/ut/python/dataset/test_graphdata_distributed.py index 97b4b8e137..22c8c6fac4 100644 --- a/tests/ut/python/dataset/test_graphdata_distributed.py +++ b/tests/ut/python/dataset/test_graphdata_distributed.py @@ -112,6 +112,10 @@ def test_graphdata_distributed(): assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0] + nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)] + edges = g.get_edges_from_nodes(nodes_pair_list) + assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37] + batch_num = 2 edge_num = g.graph_info()['edge_num'][0] out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]