From: @lizhenglong1992 Reviewed-by: Signed-off-by:pull/14874/MERGE
| @@ -57,6 +57,12 @@ PYBIND_REGISTER( | |||||
| THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); | THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); | ||||
| return out; | return out; | ||||
| }) | }) | ||||
| .def("get_edges_from_nodes", | |||||
| [](gnn::GraphData &g, std::vector<std::pair<gnn::NodeIdType, gnn::NodeIdType>> node_list) { | |||||
| std::shared_ptr<Tensor> out; | |||||
| THROW_IF_ERROR(g.GetEdgesFromNodes(node_list, &out)); | |||||
| return out; | |||||
| }) | |||||
| .def("get_all_neighbors", | .def("get_all_neighbors", | ||||
| [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) { | [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) { | ||||
| std::shared_ptr<Tensor> out; | std::shared_ptr<Tensor> out; | ||||
| @@ -50,12 +50,13 @@ enum GnnOpName { | |||||
| GET_ALL_NODES = 0; | GET_ALL_NODES = 0; | ||||
| GET_ALL_EDGES = 1; | GET_ALL_EDGES = 1; | ||||
| GET_NODES_FROM_EDGES = 2; | GET_NODES_FROM_EDGES = 2; | ||||
| GET_ALL_NEIGHBORS = 3; | |||||
| GET_SAMPLED_NEIGHBORS = 4; | |||||
| GET_NEG_SAMPLED_NEIGHBORS = 5; | |||||
| RANDOM_WALK = 6; | |||||
| GET_NODE_FEATURE = 7; | |||||
| GET_EDGE_FEATURE = 8; | |||||
| GET_EDGES_FROM_NODES = 3; | |||||
| GET_ALL_NEIGHBORS = 4; | |||||
| GET_SAMPLED_NEIGHBORS = 5; | |||||
| GET_NEG_SAMPLED_NEIGHBORS = 6; | |||||
| RANDOM_WALK = 7; | |||||
| GET_NODE_FEATURE = 8; | |||||
| GET_EDGE_FEATURE = 9; | |||||
| } | } | ||||
| message GnnRandomWalkPb { | message GnnRandomWalkPb { | ||||
| @@ -64,6 +65,11 @@ message GnnRandomWalkPb { | |||||
| int32 default_id = 3; | int32 default_id = 3; | ||||
| } | } | ||||
| message IdPairPb { | |||||
| int32 src_id = 1; | |||||
| int32 dst_id = 2; | |||||
| } | |||||
| message GnnGraphDataRequestPb { | message GnnGraphDataRequestPb { | ||||
| GnnOpName op_name = 1; | GnnOpName op_name = 1; | ||||
| repeated int32 id = 2; // node id or edge id | repeated int32 id = 2; // node id or edge id | ||||
| @@ -72,6 +78,7 @@ message GnnGraphDataRequestPb { | |||||
| TensorPb id_tensor = 5; // input ids ,node id or edge id | TensorPb id_tensor = 5; // input ids ,node id or edge id | ||||
| GnnRandomWalkPb random_walk = 6; | GnnRandomWalkPb random_walk = 6; | ||||
| int32 strategy = 7; | int32 strategy = 7; | ||||
| repeated IdPairPb node_pair = 8; | |||||
| } | } | ||||
| message GnnGraphDataResponsePb { | message GnnGraphDataResponsePb { | ||||
| @@ -62,6 +62,13 @@ class GraphData { | |||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0; | virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0; | ||||
| // Get the edge id from connected node pair | |||||
| // @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of pair nodes | |||||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||||
| // @return Status - The status code that indicate the result of function execution | |||||
| virtual Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||||
| std::shared_ptr<Tensor> *out) = 0; | |||||
| // All neighbors of the acquisition node. | // All neighbors of the acquisition node. | ||||
| // @param std::vector<NodeType> node_list - List of nodes | // @param std::vector<NodeType> node_list - List of nodes | ||||
| // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | ||||
| @@ -120,6 +120,25 @@ Status GraphDataClient::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_li | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status GraphDataClient::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| GnnGraphDataRequestPb request; | |||||
| GnnGraphDataResponsePb response; | |||||
| request.set_op_name(GET_EDGES_FROM_NODES); | |||||
| for (const auto &pair_node_id : node_list) { | |||||
| IdPairPb *proto_pair(request.add_node_pair()); | |||||
| proto_pair->set_src_id(static_cast<google::protobuf::int32>(pair_node_id.first)); | |||||
| proto_pair->set_dst_id(static_cast<google::protobuf::int32>(pair_node_id.second)); | |||||
| } | |||||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||||
| #endif | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | ||||
| std::shared_ptr<Tensor> *out) { | std::shared_ptr<Tensor> *out) { | ||||
| #if !defined(_WIN32) && !defined(_WIN64) | #if !defined(_WIN32) && !defined(_WIN64) | ||||
| @@ -72,6 +72,13 @@ class GraphDataClient : public GraphData { | |||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | ||||
| // Get the edge id from connected node pair | |||||
| // @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of pair nodes | |||||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||||
| // @return Status - The status code that indicate the result of function execution | |||||
| Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||||
| std::shared_ptr<Tensor> *out) override; | |||||
| // All neighbors of the acquisition node. | // All neighbors of the acquisition node. | ||||
| // @param std::vector<NodeType> node_list - List of nodes | // @param std::vector<NodeType> node_list - List of nodes | ||||
| // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | ||||
| @@ -128,6 +128,30 @@ Status GraphDataImpl::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status GraphDataImpl::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| if (node_list.empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("Input node list is empty."); | |||||
| } | |||||
| std::vector<std::vector<EdgeIdType>> edge_list; | |||||
| edge_list.reserve(node_list.size()); | |||||
| for (const auto &node_id : node_list) { | |||||
| std::shared_ptr<Node> src_node; | |||||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_id.first, &src_node)); | |||||
| EdgeIdType *edge_id = nullptr; | |||||
| src_node->GetEdgeByAdjNodeId(node_id.second, &edge_id); | |||||
| std::vector<EdgeIdType> connection_edge = {*edge_id}; | |||||
| edge_list.emplace_back(std::move(connection_edge)); | |||||
| } | |||||
| RETURN_IF_NOT_OK(CreateTensorByVector<EdgeIdType>(edge_list, DataType(DataType::DE_INT32), out)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | ||||
| std::shared_ptr<Tensor> *out) { | std::shared_ptr<Tensor> *out) { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | ||||
| @@ -66,6 +66,13 @@ class GraphDataImpl : public GraphData { | |||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | ||||
| // Get the edge id from connected node pair | |||||
| // @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of pair nodes | |||||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||||
| // @return Status - The status code that indicate the result of function execution | |||||
| Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list, | |||||
| std::shared_ptr<Tensor> *out) override; | |||||
| // All neighbors of the acquisition node. | // All neighbors of the acquisition node. | ||||
| // @param std::vector<NodeType> node_list - List of nodes | // @param std::vector<NodeType> node_list - List of nodes | ||||
| // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/gnn/tensor_proto.h" | #include "minddata/dataset/engine/gnn/tensor_proto.h" | ||||
| @@ -31,6 +32,7 @@ static std::unordered_map<uint32_t, pFunction> g_get_graph_data_func_ = { | |||||
| {GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes}, | {GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes}, | ||||
| {GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges}, | {GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges}, | ||||
| {GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges}, | {GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges}, | ||||
| {GET_EDGES_FROM_NODES, &GraphDataServiceImpl::GetEdgesFromNodes}, | |||||
| {GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors}, | {GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors}, | ||||
| {GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors}, | {GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors}, | ||||
| {GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors}, | {GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors}, | ||||
| @@ -189,6 +191,27 @@ Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *requ | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status GraphDataServiceImpl::GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->node_pair_size() > 0, "The input node pair id list is empty."); | |||||
| std::vector<std::pair<NodeIdType, NodeIdType>> node_list; | |||||
| node_list.resize(request->node_pair().size()); | |||||
| std::transform( | |||||
| request->node_pair().begin(), request->node_pair().end(), node_list.begin(), [](const auto &node_pair_id) { | |||||
| auto cur_pair = | |||||
| std::make_pair(static_cast<NodeIdType>(node_pair_id.src_id()), static_cast<NodeIdType>(node_pair_id.dst_id())); | |||||
| return cur_pair; | |||||
| }); | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->GetEdgesFromNodes(node_list, &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); | CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); | ||||
| @@ -50,6 +50,7 @@ class GraphDataServiceImpl { | |||||
| Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | ||||
| Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | ||||
| Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | ||||
| Status GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | ||||
| Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | ||||
| Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | ||||
| @@ -61,10 +61,14 @@ Status GraphLoader::GetNodesAndEdges() { | |||||
| std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p; | std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p; | ||||
| RETURN_IF_NOT_OK(edge_ptr->GetNode(&p)); | RETURN_IF_NOT_OK(edge_ptr->GetNode(&p)); | ||||
| auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id()); | auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id()); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); | CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); | CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); | ||||
| RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); | RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); | ||||
| RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight())); | RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight())); | ||||
| RETURN_IF_NOT_OK(src_itr->second->AddAdjacent(dst_itr->second, edge_ptr)); | |||||
| e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ | e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ | ||||
| graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); | graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); | ||||
| dq.pop_front(); | dq.pop_front(); | ||||
| @@ -131,6 +131,26 @@ Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightTyp | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status LocalNode::AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) { | |||||
| auto node_id = node->id(); | |||||
| auto edge_id = edge->id(); | |||||
| adjacent_nodes_.insert({node_id, edge_id}); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) { | |||||
| auto itr = adjacent_nodes_.find(adj_node_id); | |||||
| if (itr != adjacent_nodes_.end()) { | |||||
| (*out_edge_id) = &(itr->second); | |||||
| } else { | |||||
| (*out_edge_id) = new EdgeIdType(-1); | |||||
| MS_LOG(WARNING) << "Number " << adj_node_id << " node is not adjacent to number " << this->id() << " node."; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status LocalNode::UpdateFeature(const std::shared_ptr<Feature> &feature) { | Status LocalNode::UpdateFeature(const std::shared_ptr<Feature> &feature) { | ||||
| auto itr = features_.find(feature->type()); | auto itr = features_.find(feature->type()); | ||||
| if (itr != features_.end()) { | if (itr != features_.end()) { | ||||
| @@ -65,6 +65,18 @@ class LocalNode : public Node { | |||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override; | Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override; | ||||
| // Add adjacent node and relative edge for source node | |||||
| // @param std::shared_ptr<Node> node - the node to be inserted into adjacent table | |||||
| // @param std::shared_ptr<Edge> edge - the edge related to the adjacent node of source node | |||||
| // @return Status - The status code that indicate the result of function execution | |||||
| Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) override; | |||||
| // Get relative connecting edge of adjacent node by node id | |||||
| // @param NodeIdType - The id of adjacent node to be processed | |||||
| // @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge | |||||
| // @return Status - The status code that indicate the result of function execution | |||||
| Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) override; | |||||
| // Update feature of node | // Update feature of node | ||||
| // @param std::shared_ptr<Feature> feature - | // @param std::shared_ptr<Feature> feature - | ||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| @@ -81,6 +93,7 @@ class LocalNode : public Node { | |||||
| std::mt19937 rnd_; | std::mt19937 rnd_; | ||||
| std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_; | std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_; | ||||
| std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_; | std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_; | ||||
| std::unordered_map<NodeIdType, EdgeIdType> adjacent_nodes_; | |||||
| }; | }; | ||||
| } // namespace gnn | } // namespace gnn | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -29,9 +29,12 @@ namespace gnn { | |||||
| using NodeType = int8_t; | using NodeType = int8_t; | ||||
| using NodeIdType = int32_t; | using NodeIdType = int32_t; | ||||
| using WeightType = float; | using WeightType = float; | ||||
| using EdgeIdType = int32_t; | |||||
| constexpr NodeIdType kDefaultNodeId = -1; | constexpr NodeIdType kDefaultNodeId = -1; | ||||
| class Edge; | |||||
| class Node { | class Node { | ||||
| public: | public: | ||||
| // Constructor | // Constructor | ||||
| @@ -78,6 +81,18 @@ class Node { | |||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0; | virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0; | ||||
| // Add adjacent node and relative edge for source node | |||||
| // @param std::shared_ptr<Node> node - the node to be inserted into adjacent table | |||||
| // @param std::shared_ptr<Edge> edge - the edge related to the adjacent node of source node | |||||
| // @return Status - The status code that indicate the result of function execution | |||||
| virtual Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) = 0; | |||||
| // Get relative connecting edge of adjacent node by node id | |||||
| // @param NodeIdType - The id of adjacent node to be processed | |||||
| // @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge | |||||
| // @return Status - The status code that indicate the result of function execution | |||||
| virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) = 0; | |||||
| // Update feature of node | // Update feature of node | ||||
| // @param std::shared_ptr<Feature> feature - | // @param std::shared_ptr<Feature> feature - | ||||
| // @return Status The status code returned | // @return Status The status code returned | ||||
| @@ -102,8 +102,7 @@ Status DvppDecodePngOp::Compute(const std::shared_ptr<Tensor> &input, std::share | |||||
| unsigned char *ret_ptr = data.get(); | unsigned char *ret_ptr = data.get(); | ||||
| std::shared_ptr<DvppDataInfo> DecodeOut(process.Get_Decode_DeviceData()); | std::shared_ptr<DvppDataInfo> DecodeOut(process.Get_Decode_DeviceData()); | ||||
| dsize_t dvpp_length = DecodeOut->dataSize; | dsize_t dvpp_length = DecodeOut->dataSize; | ||||
| // dsize_t decode_height = DecodeOut->height; | |||||
| // dsize_t decode_width = DecodeOut->width; | |||||
| const TensorShape dvpp_shape({dvpp_length, 1, 1}); | const TensorShape dvpp_shape({dvpp_length, 1, 1}); | ||||
| const DataType dvpp_data_type(DataType::DE_UINT8); | const DataType dvpp_data_type(DataType::DE_UINT8); | ||||
| mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output); | mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output); | ||||
| @@ -391,6 +391,38 @@ def validate_dataset_param_value(param_list, param_dict, param_type): | |||||
| type_check(param_dict.get(param_name), (param_type,), param_name) | type_check(param_dict.get(param_name), (param_type,), param_name) | ||||
| def check_gnn_list_of_pair_or_ndarray(param, param_name): | |||||
| """ | |||||
| Check if the input parameter is a list of tuple or numpy.ndarray. | |||||
| Args: | |||||
| param (Union[list[tuple], nd.ndarray]): param. | |||||
| param_name (str): param_name. | |||||
| Returns: | |||||
| Exception: TypeError if error. | |||||
| """ | |||||
| type_check(param, (list, np.ndarray), param_name) | |||||
| if isinstance(param, list): | |||||
| param_names = ["pair_{0}".format(i) for i in range(len(param))] | |||||
| type_check_list(param, (tuple,), param_names) | |||||
| for idx, pair in enumerate(param): | |||||
| if not len(pair) == 2: | |||||
| raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format( | |||||
| param_names[idx], len(pair))) | |||||
| column_names = ["element_{0}".format(i) for i in range(len(pair))] | |||||
| type_check_list(pair, (int,), column_names) | |||||
| elif isinstance(param, np.ndarray): | |||||
| if param.ndim != 2: | |||||
| raise ValueError("Input ndarray must be in dimension 2. Got {0}".format(param.ndim)) | |||||
| if param.shape[1] != 2: | |||||
| raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format( | |||||
| param_name, param.shape[1])) | |||||
| if not param.dtype == np.int32: | |||||
| raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( | |||||
| param_name, param.dtype)) | |||||
| def check_gnn_list_or_ndarray(param, param_name): | def check_gnn_list_or_ndarray(param, param_name): | ||||
| """ | """ | ||||
| Check if the input parameter is list or numpy.ndarray. | Check if the input parameter is list or numpy.ndarray. | ||||
| @@ -26,9 +26,9 @@ from mindspore._c_dataengine import Tensor | |||||
| from mindspore._c_dataengine import SamplingStrategy as Sampling | from mindspore._c_dataengine import SamplingStrategy as Sampling | ||||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | ||||
| check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ | |||||
| check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_get_edge_feature, \ | |||||
| check_gnn_random_walk | |||||
| check_gnn_get_nodes_from_edges, check_gnn_get_edges_from_nodes, check_gnn_get_all_neighbors, \ | |||||
| check_gnn_get_sampled_neighbors, check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, \ | |||||
| check_gnn_get_edge_feature, check_gnn_random_walk | |||||
| class SamplingStrategy(IntEnum): | class SamplingStrategy(IntEnum): | ||||
| @@ -163,6 +163,27 @@ class GraphData: | |||||
| raise Exception("This method is not supported when working mode is server.") | raise Exception("This method is not supported when working mode is server.") | ||||
| return self._graph_data.get_nodes_from_edges(edge_list).as_array() | return self._graph_data.get_nodes_from_edges(edge_list).as_array() | ||||
| @check_gnn_get_edges_from_nodes | |||||
| def get_edges_from_nodes(self, node_list): | |||||
| """ | |||||
| Get edges from the nodes. | |||||
| Args: | |||||
| node_list (Union[list[tuple], numpy.ndarray]): The given list of pair nodes ID. | |||||
| Returns: | |||||
| numpy.ndarray, array of edgs ID. | |||||
| Examples: | |||||
| >>> edges = graph_dataset.get_edges_from_nodes([(1, 3), (5, 2)]) | |||||
| Raises: | |||||
| TypeError: If `edge_list` is not list or ndarray. | |||||
| """ | |||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server.") | |||||
| return self._graph_data.get_edges_from_nodes(node_list).as_array() | |||||
| @check_gnn_get_all_neighbors | @check_gnn_get_all_neighbors | ||||
| def get_all_neighbors(self, node_list, neighbor_type): | def get_all_neighbors(self, node_list, neighbor_type): | ||||
| """ | """ | ||||
| @@ -25,8 +25,8 @@ import numpy as np | |||||
| from mindspore._c_expression import typing | from mindspore._c_expression import typing | ||||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ | from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ | ||||
| INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ | INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ | ||||
| validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ | |||||
| check_columns, check_pos_int32, check_valid_str | |||||
| validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \ | |||||
| check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str | |||||
| from . import datasets | from . import datasets | ||||
| from . import samplers | from . import samplers | ||||
| @@ -1090,6 +1090,19 @@ def check_gnn_get_nodes_from_edges(method): | |||||
| return new_method | return new_method | ||||
| def check_gnn_get_edges_from_nodes(method): | |||||
| """A wrapper that wraps a parameter checker around the GNN `get_edges_from_nodes` function.""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| [node_list], _ = parse_user_args(method, *args, **kwargs) | |||||
| check_gnn_list_of_pair_or_ndarray(node_list, "node_list") | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | |||||
| def check_gnn_get_all_neighbors(method): | def check_gnn_get_all_neighbors(method): | ||||
| """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function.""" | """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function.""" | ||||
| @@ -95,6 +95,21 @@ class MindDataTestGNNGraph : public UT::Common { | |||||
| } | } | ||||
| }; | }; | ||||
| TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) { | |||||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||||
| GraphDataImpl graph(path, 1); | |||||
| Status s = graph.Init(); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| std::vector<std::pair<NodeIdType, NodeIdType>> src_dst_list = {{101, 201}, {103, 207}, {108, 208}, | |||||
| {110, 201}, {204, 105}, {208, 108}}; | |||||
| std::shared_ptr<Tensor> edges; | |||||
| s = graph.GetEdgesFromNodes(src_dst_list, &edges); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| EXPECT_TRUE(edges->ToString() == "Tensor (shape: <6>, Type: int32)\n[1,9,17,19,31,37]"); | |||||
| } | |||||
| TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | ||||
| std::string path = "data/mindrecord/testGraphData/testdata"; | std::string path = "data/mindrecord/testGraphData/testdata"; | ||||
| GraphDataImpl graph(path, 1); | GraphDataImpl graph(path, 1); | ||||
| @@ -241,6 +241,18 @@ def test_graphdata_getedgefeature(): | |||||
| assert features[1].shape == (40,) | assert features[1].shape == (40,) | ||||
| def test_graphdata_getedgesfromnodes(): | |||||
| """ | |||||
| Test get edges from nodes | |||||
| """ | |||||
| logger.info('test get_edges_from_nodes\n') | |||||
| g = ds.GraphData(DATASET_FILE) | |||||
| nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)] | |||||
| edges = g.get_edges_from_nodes(nodes_pair_list) | |||||
| assert edges.tolist() == [1, 9, 31, 17, 20, 40] | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_graphdata_getfullneighbor() | test_graphdata_getfullneighbor() | ||||
| test_graphdata_getnodefeature_input_check() | test_graphdata_getnodefeature_input_check() | ||||
| @@ -251,3 +263,4 @@ if __name__ == '__main__': | |||||
| test_graphdata_randomwalkdefault() | test_graphdata_randomwalkdefault() | ||||
| test_graphdata_randomwalk() | test_graphdata_randomwalk() | ||||
| test_graphdata_getedgefeature() | test_graphdata_getedgefeature() | ||||
| test_graphdata_getedgesfromnodes() | |||||
| @@ -112,6 +112,10 @@ def test_graphdata_distributed(): | |||||
| assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, | assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, | ||||
| 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0] | 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0] | ||||
| nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)] | |||||
| edges = g.get_edges_from_nodes(nodes_pair_list) | |||||
| assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37] | |||||
| batch_num = 2 | batch_num = 2 | ||||
| edge_num = g.graph_info()['edge_num'][0] | edge_num = g.graph_info()['edge_num'][0] | ||||
| out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"] | out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"] | ||||