2. mod cora and citeseer conversion scripttags/v0.5.0-beta
| @@ -24,9 +24,6 @@ This example provides an efficient way to generate MindRecord. Users only need t | |||
| 1. Download and prepare the Cora dataset as required. | |||
| > [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora) | |||
| 2. Edit write_cora.sh and modify the parameters | |||
| ``` | |||
| --mindrecord_file: output MindRecord file. | |||
| @@ -15,29 +15,26 @@ | |||
| """ | |||
| User-defined API for MindRecord GNN writer. | |||
| """ | |||
| import csv | |||
| import os | |||
| import pickle as pkl | |||
| import numpy as np | |||
| import scipy.sparse as sp | |||
| # parse args from command line parameter 'graph_api_args' | |||
| # args delimiter is ':' | |||
| args = os.environ['graph_api_args'].split(':') | |||
| CITESEER_CONTENT_FILE = args[0] | |||
| CITESEER_CITES_FILE = args[1] | |||
| CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord" | |||
| CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord" | |||
| node_id_map = {} | |||
| CITESEER_PATH = args[0] | |||
| dataset_str = 'citeseer' | |||
| # profile: (num_features, feature_data_types, feature_shapes) | |||
| node_profile = (2, ["float32", "int64"], [[-1], [-1]]) | |||
| node_profile = (2, ["float32", "int32"], [[-1], [-1]]) | |||
| edge_profile = (0, [], []) | |||
| node_ids = [] | |||
| def _normalize_citeseer_features(features): | |||
| features = np.array(features) | |||
| row_sum = np.array(features.sum(1)) | |||
| r_inv = np.power(row_sum * 1.0, -1).flatten() | |||
| r_inv[np.isinf(r_inv)] = 0. | |||
| @@ -46,6 +43,14 @@ def _normalize_citeseer_features(features): | |||
| return features | |||
| def _parse_index_file(filename): | |||
| """Parse index file.""" | |||
| index = [] | |||
| for line in open(filename): | |||
| index.append(int(line.strip())) | |||
| return index | |||
| def yield_nodes(task_id=0): | |||
| """ | |||
| Generate node data | |||
| @@ -54,29 +59,46 @@ def yield_nodes(task_id=0): | |||
| data (dict): data row which is dict. | |||
| """ | |||
| print("Node task is {}".format(task_id)) | |||
| label_types = {} | |||
| label_size = 0 | |||
| node_num = 0 | |||
| with open(CITESEER_CONTENT_FILE) as content_file: | |||
| content_reader = csv.reader(content_file, delimiter='\t') | |||
| line_count = 0 | |||
| for row in content_reader: | |||
| if not row[-1] in label_types: | |||
| label_types[row[-1]] = label_size | |||
| label_size += 1 | |||
| if not row[0] in node_id_map: | |||
| node_id_map[row[0]] = node_num | |||
| node_num += 1 | |||
| raw_features = [[int(x) for x in row[1:-1]]] | |||
| node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features), | |||
| 'feature_2': [label_types[row[-1]]]} | |||
| yield node | |||
| line_count += 1 | |||
| names = ['x', 'y', 'tx', 'ty', 'allx', 'ally'] | |||
| objects = [] | |||
| for name in names: | |||
| with open("{}/ind.{}.{}".format(CITESEER_PATH, dataset_str, name), 'rb') as f: | |||
| objects.append(pkl.load(f, encoding='latin1')) | |||
| x, y, tx, ty, allx, ally = tuple(objects) | |||
| test_idx_reorder = _parse_index_file( | |||
| "{}/ind.{}.test.index".format(CITESEER_PATH, dataset_str)) | |||
| test_idx_range = np.sort(test_idx_reorder) | |||
| tx = _normalize_citeseer_features(tx) | |||
| allx = _normalize_citeseer_features(allx) | |||
| # Fix citeseer dataset (there are some isolated nodes in the graph) | |||
| # Find isolated nodes, add them as zero-vecs into the right position | |||
| test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1) | |||
| tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) | |||
| tx_extended[test_idx_range-min(test_idx_range), :] = tx | |||
| tx = tx_extended | |||
| ty_extended = np.zeros((len(test_idx_range_full), y.shape[1])) | |||
| ty_extended[test_idx_range-min(test_idx_range), :] = ty | |||
| ty = ty_extended | |||
| features = sp.vstack((allx, tx)).tolil() | |||
| features[test_idx_reorder, :] = features[test_idx_range, :] | |||
| features = features.A | |||
| labels = np.vstack((ally, ty)) | |||
| labels[test_idx_reorder, :] = labels[test_idx_range, :] | |||
| line_count = 0 | |||
| for i, label in enumerate(labels): | |||
| if not 1 in label.tolist(): | |||
| continue | |||
| node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(), | |||
| 'feature_2': label.tolist().index(1)} | |||
| line_count += 1 | |||
| node_ids.append(i) | |||
| yield node | |||
| print('Processed {} lines for nodes.'.format(line_count)) | |||
| # print('label types {}.'.format(label_types)) | |||
| with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f: | |||
| for k in label_types: | |||
| print(k + ',' + str(label_types[k]), file=f) | |||
| def yield_edges(task_id=0): | |||
| @@ -87,23 +109,20 @@ def yield_edges(task_id=0): | |||
| data (dict): data row which is dict. | |||
| """ | |||
| print("Edge task is {}".format(task_id)) | |||
| # print(map_string_int) | |||
| with open(CITESEER_CITES_FILE) as cites_file: | |||
| cites_reader = csv.reader(cites_file, delimiter='\t') | |||
| with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f: | |||
| graph = pkl.load(f, encoding='latin1') | |||
| line_count = 0 | |||
| for row in cites_reader: | |||
| if not row[0] in node_id_map: | |||
| print('Source node {} does not exist.'.format(row[0])) | |||
| continue | |||
| if not row[1] in node_id_map: | |||
| print('Destination node {} does not exist.'.format(row[1])) | |||
| continue | |||
| line_count += 1 | |||
| edge = {'id': line_count, | |||
| 'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} | |||
| yield edge | |||
| with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f: | |||
| for k in node_id_map: | |||
| print(k + ',' + str(node_id_map[k]), file=f) | |||
| for i in graph: | |||
| for dst_id in graph[i]: | |||
| if not i in node_ids: | |||
| print('Source node {} does not exist.'.format(i)) | |||
| continue | |||
| if not dst_id in node_ids: | |||
| print('Destination node {} does not exist.'.format( | |||
| dst_id)) | |||
| continue | |||
| edge = {'id': line_count, | |||
| 'src_id': i, 'dst_id': dst_id, 'type': 0} | |||
| line_count += 1 | |||
| yield edge | |||
| print('Processed {} lines for edges.'.format(line_count)) | |||
| @@ -15,29 +15,24 @@ | |||
| """ | |||
| User-defined API for MindRecord GNN writer. | |||
| """ | |||
| import csv | |||
| import os | |||
| import pickle as pkl | |||
| import numpy as np | |||
| import scipy.sparse as sp | |||
| # parse args from command line parameter 'graph_api_args' | |||
| # args delimiter is ':' | |||
| args = os.environ['graph_api_args'].split(':') | |||
| CORA_CONTENT_FILE = args[0] | |||
| CORA_CITES_FILE = args[1] | |||
| CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord" | |||
| CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord" | |||
| node_id_map = {} | |||
| CORA_PATH = args[0] | |||
| dataset_str = 'cora' | |||
| # profile: (num_features, feature_data_types, feature_shapes) | |||
| node_profile = (2, ["float32", "int64"], [[-1], [-1]]) | |||
| node_profile = (2, ["float32", "int32"], [[-1], [-1]]) | |||
| edge_profile = (0, [], []) | |||
| def _normalize_cora_features(features): | |||
| features = np.array(features) | |||
| row_sum = np.array(features.sum(1)) | |||
| r_inv = np.power(row_sum * 1.0, -1).flatten() | |||
| r_inv[np.isinf(r_inv)] = 0. | |||
| @@ -46,6 +41,14 @@ def _normalize_cora_features(features): | |||
| return features | |||
| def _parse_index_file(filename): | |||
| """Parse index file.""" | |||
| index = [] | |||
| for line in open(filename): | |||
| index.append(int(line.strip())) | |||
| return index | |||
| def yield_nodes(task_id=0): | |||
| """ | |||
| Generate node data | |||
| @@ -54,32 +57,32 @@ def yield_nodes(task_id=0): | |||
| data (dict): data row which is dict. | |||
| """ | |||
| print("Node task is {}".format(task_id)) | |||
| label_types = {} | |||
| label_size = 0 | |||
| node_num = 0 | |||
| with open(CORA_CONTENT_FILE) as content_file: | |||
| content_reader = csv.reader(content_file, delimiter=',') | |||
| line_count = 0 | |||
| for row in content_reader: | |||
| if line_count == 0: | |||
| line_count += 1 | |||
| continue | |||
| if not row[0] in node_id_map: | |||
| node_id_map[row[0]] = node_num | |||
| node_num += 1 | |||
| if not row[-1] in label_types: | |||
| label_types[row[-1]] = label_size | |||
| label_size += 1 | |||
| raw_features = [[int(x) for x in row[1:-1]]] | |||
| node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features), | |||
| 'feature_2': [label_types[row[-1]]]} | |||
| yield node | |||
| line_count += 1 | |||
| names = ['tx', 'ty', 'allx', 'ally'] | |||
| objects = [] | |||
| for name in names: | |||
| with open("{}/ind.{}.{}".format(CORA_PATH, dataset_str, name), 'rb') as f: | |||
| objects.append(pkl.load(f, encoding='latin1')) | |||
| tx, ty, allx, ally = tuple(objects) | |||
| test_idx_reorder = _parse_index_file( | |||
| "{}/ind.{}.test.index".format(CORA_PATH, dataset_str)) | |||
| test_idx_range = np.sort(test_idx_reorder) | |||
| features = sp.vstack((allx, tx)).tolil() | |||
| features[test_idx_reorder, :] = features[test_idx_range, :] | |||
| features = _normalize_cora_features(features) | |||
| features = features.A | |||
| labels = np.vstack((ally, ty)) | |||
| labels[test_idx_reorder, :] = labels[test_idx_range, :] | |||
| line_count = 0 | |||
| for i, label in enumerate(labels): | |||
| node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(), | |||
| 'feature_2': label.tolist().index(1)} | |||
| line_count += 1 | |||
| yield node | |||
| print('Processed {} lines for nodes.'.format(line_count)) | |||
| print('label types {}.'.format(label_types)) | |||
| with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f: | |||
| for k in label_types: | |||
| print(k + ',' + str(label_types[k]), file=f) | |||
| def yield_edges(task_id=0): | |||
| @@ -90,24 +93,13 @@ def yield_edges(task_id=0): | |||
| data (dict): data row which is dict. | |||
| """ | |||
| print("Edge task is {}".format(task_id)) | |||
| with open(CORA_CITES_FILE) as cites_file: | |||
| cites_reader = csv.reader(cites_file, delimiter=',') | |||
| with open("{}/ind.{}.graph".format(CORA_PATH, dataset_str), 'rb') as f: | |||
| graph = pkl.load(f, encoding='latin1') | |||
| line_count = 0 | |||
| for row in cites_reader: | |||
| if line_count == 0: | |||
| for i in graph: | |||
| for dst_id in graph[i]: | |||
| edge = {'id': line_count, | |||
| 'src_id': i, 'dst_id': dst_id, 'type': 0} | |||
| line_count += 1 | |||
| continue | |||
| if not row[0] in node_id_map: | |||
| print('Source node {} does not exist.'.format(row[0])) | |||
| continue | |||
| if not row[1] in node_id_map: | |||
| print('Destination node {} does not exist.'.format(row[1])) | |||
| continue | |||
| edge = {'id': line_count, | |||
| 'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} | |||
| yield edge | |||
| line_count += 1 | |||
| yield edge | |||
| print('Processed {} lines for edges.'.format(line_count)) | |||
| with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f: | |||
| for k in node_id_map: | |||
| print(k + ',' + str(node_id_map[k]), file=f) | |||
| @@ -9,4 +9,4 @@ python writer.py --mindrecord_script citeseer \ | |||
| --mindrecord_partitions 1 \ | |||
| --mindrecord_header_size_by_bit 18 \ | |||
| --mindrecord_page_size_by_bit 20 \ | |||
| --graph_api_args "$SRC_PATH/citeseer.content:$SRC_PATH/citeseer.cites" | |||
| --graph_api_args "$SRC_PATH" | |||
| @@ -9,4 +9,4 @@ python writer.py --mindrecord_script cora \ | |||
| --mindrecord_partitions 1 \ | |||
| --mindrecord_header_size_by_bit 18 \ | |||
| --mindrecord_page_size_by_bit 20 \ | |||
| --graph_api_args "$SRC_PATH/cora_content.csv:$SRC_PATH/cora_cites.csv" | |||
| --graph_api_args "$SRC_PATH" | |||
| @@ -527,10 +527,22 @@ void bindGraphData(py::module *m) { | |||
| THROW_IF_ERROR(g_out->Init()); | |||
| return g_out; | |||
| })) | |||
| .def("get_nodes", | |||
| [](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) { | |||
| .def("get_all_nodes", | |||
| [](gnn::Graph &g, gnn::NodeType node_type) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out)); | |||
| THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_all_edges", | |||
| [](gnn::Graph &g, gnn::EdgeType edge_type) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_nodes_from_edges", | |||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_all_neighbors", | |||
| @@ -539,12 +551,31 @@ void bindGraphData(py::module *m) { | |||
| THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_sampled_neighbors", | |||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums, | |||
| std::vector<gnn::NodeType> neighbor_types) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_neg_sampled_neighbors", | |||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num, | |||
| gnn::NodeType neg_neighbor_type) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_node_feature", | |||
| [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) { | |||
| TensorRow out; | |||
| THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); | |||
| return out; | |||
| }); | |||
| }) | |||
| .def("graph_info", [](gnn::Graph &g) { | |||
| py::dict out; | |||
| THROW_IF_ERROR(g.GraphInfo(&out)); | |||
| return out; | |||
| }); | |||
| } | |||
| // This is where we externalize the C logic as python modules | |||
| @@ -17,29 +17,30 @@ | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <iterator> | |||
| #include <numeric> | |||
| #include <utility> | |||
| #include "dataset/core/tensor_shape.h" | |||
| #include "dataset/util/random.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| Graph::Graph(std::string dataset_file, int32_t num_workers) : dataset_file_(dataset_file), num_workers_(num_workers) { | |||
| Graph::Graph(std::string dataset_file, int32_t num_workers) | |||
| : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) { | |||
| rnd_.seed(GetSeed()); | |||
| MS_LOG(INFO) << "num_workers:" << num_workers; | |||
| } | |||
| Status Graph::GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out) { | |||
| Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { | |||
| auto itr = node_type_map_.find(node_type); | |||
| if (itr == node_type_map_.end()) { | |||
| std::string err_msg = "Invalid node type:" + std::to_string(node_type); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } else { | |||
| if (node_num == -1) { | |||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out)); | |||
| } else { | |||
| } | |||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -59,9 +60,9 @@ Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, Data | |||
| RETURN_IF_NOT_OK(Tensor::CreateTensor( | |||
| &tensor, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(m), static_cast<dsize_t>(n)}), type, nullptr)); | |||
| T *ptr = reinterpret_cast<T *>(tensor->GetMutableBuffer()); | |||
| for (auto id_m : data) { | |||
| for (const auto &id_m : data) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); | |||
| for (auto id_n : id_m) { | |||
| for (const auto &id_n : id_m) { | |||
| *ptr = id_n; | |||
| ptr++; | |||
| } | |||
| @@ -89,7 +90,38 @@ Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_siz | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out) { return Status::OK(); } | |||
| Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { | |||
| auto itr = edge_type_map_.find(edge_type); | |||
| if (itr == edge_type_map_.end()) { | |||
| std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } else { | |||
| RETURN_IF_NOT_OK(CreateTensorByVector<EdgeIdType>({itr->second}, DataType(DataType::DE_INT32), out)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) { | |||
| if (edge_list.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); | |||
| } | |||
| std::vector<std::vector<NodeIdType>> node_list; | |||
| node_list.reserve(edge_list.size()); | |||
| for (const auto &edge_id : edge_list) { | |||
| auto itr = edge_id_map_.find(edge_id); | |||
| if (itr == edge_id_map_.end()) { | |||
| std::string err_msg = "Invalid edge id:" + std::to_string(edge_id); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } else { | |||
| std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> nodes; | |||
| RETURN_IF_NOT_OK(itr->second->GetNode(&nodes)); | |||
| node_list.push_back({nodes.first->id(), nodes.second->id()}); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(node_list, DataType(DataType::DE_INT32), out)); | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| @@ -105,14 +137,10 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType | |||
| size_t max_neighbor_num = 0; | |||
| neighbors.resize(node_list.size()); | |||
| for (size_t i = 0; i < node_list.size(); ++i) { | |||
| auto itr = node_id_map_.find(node_list[i]); | |||
| if (itr != node_id_map_.end()) { | |||
| RETURN_IF_NOT_OK(itr->second->GetNeighbors(neighbor_type, -1, &neighbors[i])); | |||
| max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); | |||
| } else { | |||
| std::string err_msg = "Invalid node id:" + std::to_string(node_list[i]); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| std::shared_ptr<Node> node; | |||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node)); | |||
| RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i])); | |||
| max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); | |||
| } | |||
| RETURN_IF_NOT_OK(ComplementVector<NodeIdType>(&neighbors, max_neighbor_num, kDefaultNodeId)); | |||
| @@ -121,13 +149,94 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||
| Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), | |||
| "The sizes of neighbor_nums and neighbor_types are inconsistent."); | |||
| std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size()); | |||
| for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | |||
| neighbors_vec[node_idx].emplace_back(node_list[node_idx]); | |||
| std::vector<NodeIdType> input_list = {node_list[node_idx]}; | |||
| for (size_t i = 0; i < neighbor_nums.size(); ++i) { | |||
| std::vector<NodeIdType> neighbors; | |||
| neighbors.reserve(input_list.size() * neighbor_nums[i]); | |||
| for (const auto &node_id : input_list) { | |||
| if (node_id == kDefaultNodeId) { | |||
| for (int32_t j = 0; j < neighbor_nums[i]; ++j) { | |||
| neighbors.emplace_back(kDefaultNodeId); | |||
| } | |||
| } else { | |||
| std::shared_ptr<Node> node; | |||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); | |||
| std::vector<NodeIdType> out; | |||
| RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); | |||
| neighbors.insert(neighbors.end(), out.begin(), out.end()); | |||
| } | |||
| } | |||
| neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end()); | |||
| input_list = std::move(neighbors); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out)); | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||
| Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::unordered_set<NodeIdType> &exclude_data, | |||
| int32_t samples_num, std::vector<NodeIdType> *out_samples) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); | |||
| std::vector<NodeIdType> shuffled_id(data.size()); | |||
| std::iota(shuffled_id.begin(), shuffled_id.end(), 0); | |||
| std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); | |||
| for (const auto &index : shuffled_id) { | |||
| if (exclude_data.find(data[index]) != exclude_data.end()) { | |||
| continue; | |||
| } | |||
| out_samples->emplace_back(data[index]); | |||
| if (out_samples->size() >= samples_num) { | |||
| break; | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| std::vector<std::vector<NodeIdType>> neighbors_vec; | |||
| neighbors_vec.resize(node_list.size()); | |||
| for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | |||
| std::shared_ptr<Node> node; | |||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); | |||
| std::vector<NodeIdType> neighbors; | |||
| RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); | |||
| std::unordered_set<NodeIdType> exclude_node; | |||
| std::transform(neighbors.begin(), neighbors.end(), | |||
| std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_node, exclude_node.begin()), | |||
| [](const NodeIdType node) { return node; }); | |||
| auto itr = node_type_map_.find(neg_neighbor_type); | |||
| if (itr == node_type_map_.end()) { | |||
| std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } else { | |||
| neighbors_vec[node_idx].emplace_back(node->id()); | |||
| if (itr->second.size() > exclude_node.size()) { | |||
| while (neighbors_vec[node_idx].size() < samples_num + 1) { | |||
| RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node, samples_num - neighbors_vec[node_idx].size(), | |||
| &neighbors_vec[node_idx])); | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() | |||
| << " neg_neighbor_type:" << neg_neighbor_type; | |||
| // If there are no negative neighbors, they are filled with kDefaultNodeId | |||
| for (int32_t i = 0; i < samples_num; ++i) { | |||
| neighbors_vec[node_idx].emplace_back(kDefaultNodeId); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -154,7 +263,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty"); | |||
| TensorRow tensors; | |||
| for (auto f_type : feature_types) { | |||
| for (const auto &f_type : feature_types) { | |||
| std::shared_ptr<Feature> default_feature; | |||
| // If no feature can be obtained, fill in the default value | |||
| RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature)); | |||
| @@ -169,18 +278,14 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve | |||
| dsize_t index = 0; | |||
| for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) { | |||
| auto itr = node_id_map_.find(*node_itr); | |||
| std::shared_ptr<Feature> feature; | |||
| if (itr != node_id_map_.end()) { | |||
| if (!itr->second->GetFeatures(f_type, &feature).IsOk()) { | |||
| feature = default_feature; | |||
| } | |||
| if (*node_itr == kDefaultNodeId) { | |||
| feature = default_feature; | |||
| } else { | |||
| if (*node_itr == kDefaultNodeId) { | |||
| std::shared_ptr<Node> node; | |||
| RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); | |||
| if (!node->GetFeatures(f_type, &feature).IsOk()) { | |||
| feature = default_feature; | |||
| } else { | |||
| std::string err_msg = "Invalid node id:" + std::to_string(*node_itr); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); | |||
| @@ -209,35 +314,54 @@ Status Graph::Init() { | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info) { | |||
| node_info->reserve(node_type_map_.size()); | |||
| for (auto node : node_type_map_) { | |||
| NodeMetaInfo n_info; | |||
| n_info.type = node.first; | |||
| n_info.num = node.second.size(); | |||
| auto itr = node_feature_map_.find(node.first); | |||
| if (itr != node_feature_map_.end()) { | |||
| for (auto f_type : itr->second) { | |||
| n_info.feature_type.push_back(f_type); | |||
| } | |||
| std::sort(n_info.feature_type.begin(), n_info.feature_type.end()); | |||
| Status Graph::GetMetaInfo(MetaInfo *meta_info) { | |||
| meta_info->node_type.resize(node_type_map_.size()); | |||
| std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), | |||
| [](auto itr) { return itr.first; }); | |||
| std::sort(meta_info->node_type.begin(), meta_info->node_type.end()); | |||
| meta_info->edge_type.resize(edge_type_map_.size()); | |||
| std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(), | |||
| [](auto itr) { return itr.first; }); | |||
| std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end()); | |||
| for (const auto &node : node_type_map_) { | |||
| meta_info->node_num[node.first] = node.second.size(); | |||
| } | |||
| for (const auto &edge : edge_type_map_) { | |||
| meta_info->edge_num[edge.first] = edge.second.size(); | |||
| } | |||
| for (const auto &node_feature : node_feature_map_) { | |||
| for (auto type : node_feature.second) { | |||
| meta_info->node_feature_type.emplace_back(type); | |||
| } | |||
| node_info->push_back(n_info); | |||
| } | |||
| edge_info->reserve(edge_type_map_.size()); | |||
| for (auto edge : edge_type_map_) { | |||
| EdgeMetaInfo e_info; | |||
| e_info.type = edge.first; | |||
| e_info.num = edge.second.size(); | |||
| auto itr = edge_feature_map_.find(edge.first); | |||
| if (itr != edge_feature_map_.end()) { | |||
| for (auto f_type : itr->second) { | |||
| e_info.feature_type.push_back(f_type); | |||
| } | |||
| } | |||
| std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); | |||
| auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); | |||
| meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end()); | |||
| for (const auto &edge_feature : edge_feature_map_) { | |||
| for (const auto &type : edge_feature.second) { | |||
| meta_info->edge_feature_type.emplace_back(type); | |||
| } | |||
| edge_info->push_back(e_info); | |||
| } | |||
| std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); | |||
| auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); | |||
| meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end()); | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GraphInfo(py::dict *out) { | |||
| MetaInfo meta_info; | |||
| RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); | |||
| (*out)["node_type"] = py::cast(meta_info.node_type); | |||
| (*out)["edge_type"] = py::cast(meta_info.edge_type); | |||
| (*out)["node_num"] = py::cast(meta_info.node_num); | |||
| (*out)["edge_num"] = py::cast(meta_info.edge_num); | |||
| (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type); | |||
| (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type); | |||
| return Status::OK(); | |||
| } | |||
| @@ -250,6 +374,18 @@ Status Graph::LoadNodeAndEdge() { | |||
| &node_feature_map_, &edge_feature_map_, &default_feature_map_)); | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||
| auto itr = node_id_map_.find(id); | |||
| if (itr == node_id_map_.end()) { | |||
| std::string err_msg = "Invalid node id:" + std::to_string(id); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } else { | |||
| *node = itr->second; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -18,6 +18,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| @@ -33,24 +34,13 @@ namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| struct NodeMetaInfo { | |||
| NodeType type; | |||
| NodeIdType num; | |||
| std::vector<FeatureType> feature_type; | |||
| NodeMetaInfo() { | |||
| type = 0; | |||
| num = 0; | |||
| } | |||
| }; | |||
| struct EdgeMetaInfo { | |||
| EdgeType type; | |||
| EdgeIdType num; | |||
| std::vector<FeatureType> feature_type; | |||
| EdgeMetaInfo() { | |||
| type = 0; | |||
| num = 0; | |||
| } | |||
| struct MetaInfo { | |||
| std::vector<NodeType> node_type; | |||
| std::vector<EdgeType> edge_type; | |||
| std::map<NodeType, NodeIdType> node_num; | |||
| std::map<EdgeType, EdgeIdType> edge_num; | |||
| std::vector<FeatureType> node_feature_type; | |||
| std::vector<FeatureType> edge_feature_type; | |||
| }; | |||
| class Graph { | |||
| @@ -62,19 +52,23 @@ class Graph { | |||
| ~Graph() = default; | |||
| // Get the nodes from the graph. | |||
| // Get all nodes from the graph. | |||
| // @param NodeType node_type - type of node | |||
| // @param NodeIdType node_num - Number of nodes to be acquired, if -1 means all nodes are acquired | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id | |||
| // @return Status - The error code return | |||
| Status GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out); | |||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out); | |||
| // Get the edges from the graph. | |||
| // Get all edges from the graph. | |||
| // @param NodeType edge_type - type of edge | |||
| // @param NodeIdType edge_num - Number of edges to be acquired, if -1 means all edges are acquired | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The error code return | |||
| Status GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out); | |||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out); | |||
| // Get the node id from the edge. | |||
| // @param std::vector<EdgeIdType> edge_list - List of edges | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code return | |||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out); | |||
| // All neighbors of the acquisition node. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| @@ -86,10 +80,24 @@ class Graph { | |||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out); | |||
| Status GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out); | |||
| Status GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out); | |||
| // Get sampled neighbors. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||
| // @return Status - The error code return | |||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out); | |||
| // Get negative sampled neighbors. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| // @param NodeIdType samples_num - Number of neighbors sampled | |||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | |||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | |||
| // @return Status - The error code return | |||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out); | |||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q, | |||
| NodeIdType default_node, std::shared_ptr<Tensor> *out); | |||
| @@ -112,10 +120,12 @@ class Graph { | |||
| TensorRow *out); | |||
| // Get meta information of graph | |||
| // @param std::vector<NodeMetaInfo> *node_info - Returned meta information of node | |||
| // @param std::vector<NodeMetaInfo> *node_info - Returned meta information of edge | |||
| // @param MetaInfo *meta_info - Returned meta information | |||
| // @return Status - The error code return | |||
| Status GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info); | |||
| Status GetMetaInfo(MetaInfo *meta_info); | |||
| // Return meta information to python layer | |||
| Status GraphInfo(py::dict *out); | |||
| Status Init(); | |||
| @@ -146,8 +156,24 @@ class Graph { | |||
| // @return Status - The error code return | |||
| Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | |||
| // Find node object using node id | |||
| // @param NodeIdType id - | |||
| // @param std::shared_ptr<Node> *node - Returned node object | |||
| // @return Status - The error code return | |||
| Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node); | |||
| // Negative sampling | |||
| // @param std::vector<NodeIdType> &input_data - The data set to be sampled | |||
| // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded | |||
| // @param int32_t samples_num - | |||
| // @param std::vector<NodeIdType> *out_samples - Sampling results returned | |||
| // @return Status - The error code return | |||
| Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data, | |||
| int32_t samples_num, std::vector<NodeIdType> *out_samples); | |||
| std::string dataset_file_; | |||
| int32_t num_workers_; // The number of worker threads | |||
| std::mt19937 rnd_; | |||
| std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_; | |||
| std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_; | |||
| @@ -20,12 +20,13 @@ | |||
| #include <utility> | |||
| #include "dataset/engine/gnn/edge.h" | |||
| #include "dataset/util/random.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type) {} | |||
| LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } | |||
| Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||
| auto itr = features_.find(feature_type); | |||
| @@ -38,21 +39,49 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> | |||
| } | |||
| } | |||
| Status LocalNode::GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) { | |||
| Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) { | |||
| std::vector<NodeIdType> neighbors; | |||
| auto itr = neighbor_nodes_.find(neighbor_type); | |||
| if (itr != neighbor_nodes_.end()) { | |||
| if (samples_num == -1) { | |||
| // Return all neighbors | |||
| neighbors.resize(itr->second.size() + 1); | |||
| neighbors[0] = id_; | |||
| std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, | |||
| [](const std::shared_ptr<Node> node) { return node->id(); }); | |||
| } else { | |||
| } | |||
| neighbors.resize(itr->second.size() + 1); | |||
| neighbors[0] = id_; | |||
| std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, | |||
| [](const std::shared_ptr<Node> node) { return node->id(); }); | |||
| } else { | |||
| neighbors.push_back(id_); | |||
| MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; | |||
| neighbors.emplace_back(id_); | |||
| } | |||
| *out_neighbors = std::move(neighbors); | |||
| return Status::OK(); | |||
| } | |||
| Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, | |||
| std::vector<NodeIdType> *out) { | |||
| std::vector<NodeIdType> shuffled_id(neighbors.size()); | |||
| std::iota(shuffled_id.begin(), shuffled_id.end(), 0); | |||
| std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); | |||
| int32_t num = std::min(samples_num, static_cast<int32_t>(neighbors.size())); | |||
| for (int32_t i = 0; i < num; ++i) { | |||
| out->emplace_back(neighbors[shuffled_id[i]]->id()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||
| std::vector<NodeIdType> *out_neighbors) { | |||
| std::vector<NodeIdType> neighbors; | |||
| neighbors.reserve(samples_num); | |||
| auto itr = neighbor_nodes_.find(neighbor_type); | |||
| if (itr != neighbor_nodes_.end()) { | |||
| while (neighbors.size() < samples_num) { | |||
| RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; | |||
| // If there are no neighbors, they are filled with kDefaultNodeId | |||
| for (int32_t i = 0; i < samples_num; ++i) { | |||
| neighbors.emplace_back(kDefaultNodeId); | |||
| } | |||
| } | |||
| *out_neighbors = std::move(neighbors); | |||
| return Status::OK(); | |||
| @@ -43,12 +43,19 @@ class LocalNode : public Node { | |||
| // @return Status - The error code return | |||
| Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; | |||
| // Get the neighbors of a node | |||
| // Get the all neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) override; | |||
| Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) override; | |||
| // Get the sampled neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param int32_t samples_num - Number of neighbors to be acquired | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||
| std::vector<NodeIdType> *out_neighbors) override; | |||
| // Add neighbor of node | |||
| // @param std::shared_ptr<Node> node - | |||
| @@ -61,6 +68,10 @@ class LocalNode : public Node { | |||
| Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | |||
| private: | |||
| Status GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, | |||
| std::vector<NodeIdType> *out); | |||
| std::mt19937 rnd_; | |||
| std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_; | |||
| std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_; | |||
| }; | |||
| @@ -52,12 +52,19 @@ class Node { | |||
| // @return Status - The error code return | |||
| virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; | |||
| // Get the neighbors of a node | |||
| // Get the all neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| virtual Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) = 0; | |||
| virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) = 0; | |||
| // Get the sampled neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param int32_t samples_num - Number of neighbors to be acquired | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||
| std::vector<NodeIdType> *out_neighbors) = 0; | |||
| // Add neighbor of node | |||
| // @param std::shared_ptr<Node> node - | |||
| @@ -20,8 +20,9 @@ import numpy as np | |||
| from mindspore._c_dataengine import Graph | |||
| from mindspore._c_dataengine import Tensor | |||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \ | |||
| check_gnn_get_node_feature | |||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | |||
| check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ | |||
| check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature | |||
| class GraphData: | |||
| @@ -60,7 +61,44 @@ class GraphData: | |||
| Raises: | |||
| TypeError: If `node_type` is not integer. | |||
| """ | |||
| return self._graph.get_nodes(node_type, -1).as_array() | |||
| return self._graph.get_all_nodes(node_type).as_array() | |||
| @check_gnn_get_all_edges | |||
| def get_all_edges(self, edge_type): | |||
| """ | |||
| Get all edges in the graph. | |||
| Args: | |||
| edge_type (int): Specify the type of edge. | |||
| Returns: | |||
| numpy.ndarray: array of edges. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> data_graph = ds.GraphData('dataset_file', 2) | |||
| >>> nodes = data_graph.get_all_edges(0) | |||
| Raises: | |||
| TypeError: If `edge_type` is not integer. | |||
| """ | |||
| return self._graph.get_all_edges(edge_type).as_array() | |||
| @check_gnn_get_nodes_from_edges | |||
| def get_nodes_from_edges(self, edge_list): | |||
| """ | |||
| Get nodes from the edges. | |||
| Args: | |||
| edge_list (list or numpy.ndarray): The given list of edges. | |||
| Returns: | |||
| numpy.ndarray: array of nodes. | |||
| Raises: | |||
| TypeError: If `edge_list` is not list or ndarray. | |||
| """ | |||
| return self._graph.get_nodes_from_edges(edge_list).as_array() | |||
| @check_gnn_get_all_neighbors | |||
| def get_all_neighbors(self, node_list, neighbor_type): | |||
| @@ -86,6 +124,58 @@ class GraphData: | |||
| """ | |||
| return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() | |||
| @check_gnn_get_sampled_neighbors | |||
| def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): | |||
| """ | |||
| Get sampled neighbor information, maximum support 6-hop sampling. | |||
| Args: | |||
| node_list (list or numpy.ndarray): The given list of nodes. | |||
| neighbor_nums (list or numpy.ndarray): Number of neighbors sampled per hop. | |||
| neighbor_types (list or numpy.ndarray): Neighbor type sampled per hop. | |||
| Returns: | |||
| numpy.ndarray: array of nodes. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> data_graph = ds.GraphData('dataset_file', 2) | |||
| >>> nodes = data_graph.get_all_nodes(0) | |||
| >>> neighbors = data_graph.get_all_neighbors(nodes, [2, 2], [0, 0]) | |||
| Raises: | |||
| TypeError: If `node_list` is not list or ndarray. | |||
| TypeError: If `neighbor_nums` is not list or ndarray. | |||
| TypeError: If `neighbor_types` is not list or ndarray. | |||
| """ | |||
| return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array() | |||
| @check_gnn_get_neg_sampled_neighbors | |||
| def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): | |||
| """ | |||
| Get `neg_neighbor_type` negative sampled neighbors of the nodes in `node_list`. | |||
| Args: | |||
| node_list (list or numpy.ndarray): The given list of nodes. | |||
| neg_neighbor_num (int): Number of neighbors sampled. | |||
| neg_neighbor_type (int): Specify the type of negative neighbor. | |||
| Returns: | |||
| numpy.ndarray: array of nodes. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> data_graph = ds.GraphData('dataset_file', 2) | |||
| >>> nodes = data_graph.get_all_nodes(0) | |||
| >>> neg_neighbors = data_graph.get_neg_sampled_neighbors(nodes, 5, 0) | |||
| Raises: | |||
| TypeError: If `node_list` is not list or ndarray. | |||
| TypeError: If `neg_neighbor_num` is not integer. | |||
| TypeError: If `neg_neighbor_type` is not integer. | |||
| """ | |||
| return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array() | |||
| @check_gnn_get_node_feature | |||
| def get_node_feature(self, node_list, feature_types): | |||
| """ | |||
| @@ -111,3 +201,13 @@ class GraphData: | |||
| if isinstance(node_list, list): | |||
| node_list = np.array(node_list, dtype=np.int32) | |||
| return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)] | |||
| def graph_info(self): | |||
| """ | |||
| Get the meta information of the graph, including the number of nodes, the type of nodes, | |||
| the feature information of nodes, the number of edges, the type of edges, and the feature information of edges. | |||
| Returns: | |||
| Dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num, | |||
| node_feature_type and edge_feature_type. | |||
| """ | |||
| return self._graph.graph_info() | |||
| @@ -1153,6 +1153,36 @@ def check_gnn_get_all_nodes(method): | |||
| return new_method | |||
| def check_gnn_get_all_edges(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_all_edges` function.""" | |||
| @wraps(method) | |||
| def new_method(*args, **kwargs): | |||
| param_dict = make_param_dict(method, args, kwargs) | |||
| # check node_type; required argument | |||
| check_type(param_dict.get("edge_type"), 'edge_type', int) | |||
| return method(*args, **kwargs) | |||
| return new_method | |||
| def check_gnn_get_nodes_from_edges(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function.""" | |||
| @wraps(method) | |||
| def new_method(*args, **kwargs): | |||
| param_dict = make_param_dict(method, args, kwargs) | |||
| # check edge_list; required argument | |||
| check_gnn_list_or_ndarray(param_dict.get("edge_list"), 'edge_list') | |||
| return method(*args, **kwargs) | |||
| return new_method | |||
| def check_gnn_get_all_neighbors(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" | |||
| @@ -1171,6 +1201,61 @@ def check_gnn_get_all_neighbors(method): | |||
| return new_method | |||
| def check_gnn_get_sampled_neighbors(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function.""" | |||
| @wraps(method) | |||
| def new_method(*args, **kwargs): | |||
| param_dict = make_param_dict(method, args, kwargs) | |||
| # check node_list; required argument | |||
| check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') | |||
| # check neighbor_nums; required argument | |||
| neighbor_nums = param_dict.get("neighbor_nums") | |||
| check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums') | |||
| if len(neighbor_nums) > 6: | |||
| raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( | |||
| 'neighbor_nums', len(neighbor_nums))) | |||
| # check neighbor_types; required argument | |||
| neighbor_types = param_dict.get("neighbor_types") | |||
| check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types') | |||
| if len(neighbor_nums) > 6: | |||
| raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( | |||
| 'neighbor_types', len(neighbor_types))) | |||
| if len(neighbor_nums) != len(neighbor_types): | |||
| raise ValueError( | |||
| "The number of members of neighbor_nums and neighbor_types is inconsistent") | |||
| return method(*args, **kwargs) | |||
| return new_method | |||
| def check_gnn_get_neg_sampled_neighbors(method): | |||
| """A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" | |||
| @wraps(method) | |||
| def new_method(*args, **kwargs): | |||
| param_dict = make_param_dict(method, args, kwargs) | |||
| # check node_list; required argument | |||
| check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') | |||
| # check neg_neighbor_num; required argument | |||
| check_type(param_dict.get("neg_neighbor_num"), 'neg_neighbor_num', int) | |||
| # check neg_neighbor_type; required argument | |||
| check_type(param_dict.get("neg_neighbor_type"), | |||
| 'neg_neighbor_type', int) | |||
| return method(*args, **kwargs) | |||
| return new_method | |||
| def check_aligned_list(param, param_name, membor_type): | |||
| """Check whether the structure of each member of the list is the same.""" | |||
| @@ -13,8 +13,10 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <unordered_set> | |||
| #include "common/common.h" | |||
| #include "gtest/gtest.h" | |||
| @@ -45,7 +47,7 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) { | |||
| &default_feature_map) | |||
| .IsOk()); | |||
| EXPECT_EQ(n_id_map.size(), 20); | |||
| EXPECT_EQ(e_id_map.size(), 20); | |||
| EXPECT_EQ(e_id_map.size(), 40); | |||
| EXPECT_EQ(n_type_map[2].size(), 10); | |||
| EXPECT_EQ(n_type_map[1].size(), 10); | |||
| } | |||
| @@ -56,14 +58,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::vector<NodeMetaInfo> node_info; | |||
| std::vector<EdgeMetaInfo> edge_info; | |||
| s = graph.GetMetaInfo(&node_info, &edge_info); | |||
| MetaInfo meta_info; | |||
| s = graph.GetMetaInfo(&meta_info); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(node_info.size() == 2); | |||
| EXPECT_TRUE(meta_info.node_type.size() == 2); | |||
| std::shared_ptr<Tensor> nodes; | |||
| s = graph.GetNodes(node_info[1].type, -1, &nodes); | |||
| s = graph.GetAllNodes(meta_info.node_type[0], &nodes); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::vector<NodeIdType> node_list; | |||
| for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | |||
| @@ -73,13 +74,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||
| } | |||
| } | |||
| std::shared_ptr<Tensor> neighbors; | |||
| s = graph.GetAllNeighbors(node_list, node_info[0].type, &neighbors); | |||
| s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], &neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>"); | |||
| TensorRow features; | |||
| s = graph.GetNodeFeature(nodes, node_info[1].feature_type, &features); | |||
| s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(features.size() == 3); | |||
| EXPECT_TRUE(features.size() == 4); | |||
| EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>"); | |||
| EXPECT_TRUE(features[0]->ToString() == | |||
| "Tensor (shape: <10,5>, Type: int32)\n" | |||
| @@ -91,3 +92,106 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||
| EXPECT_TRUE(features[2]->shape().ToString() == "<10>"); | |||
| EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]"); | |||
| } | |||
| TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | |||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||
| Graph graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| MetaInfo meta_info; | |||
| s = graph.GetMetaInfo(&meta_info); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(meta_info.node_type.size() == 2); | |||
| std::shared_ptr<Tensor> edges; | |||
| s = graph.GetAllEdges(meta_info.edge_type[0], &edges); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::vector<EdgeIdType> edge_list; | |||
| edge_list.resize(edges->Size()); | |||
| std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(), | |||
| [](const EdgeIdType edge) { return edge; }); | |||
| std::shared_ptr<Tensor> nodes; | |||
| s = graph.GetNodesFromEdges(edge_list, &nodes); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::unordered_set<NodeIdType> node_set; | |||
| std::vector<NodeIdType> node_list; | |||
| int index = 0; | |||
| for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | |||
| index++; | |||
| if (index % 2 == 0) { | |||
| continue; | |||
| } | |||
| node_set.emplace(*itr); | |||
| if (node_set.size() >= 5) { | |||
| break; | |||
| } | |||
| } | |||
| node_list.resize(node_set.size()); | |||
| std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; }); | |||
| std::shared_ptr<Tensor> neighbors; | |||
| s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>"); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, | |||
| {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>"); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != | |||
| std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos); | |||
| } | |||
| TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { | |||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||
| Graph graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| MetaInfo meta_info; | |||
| s = graph.GetMetaInfo(&meta_info); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(meta_info.node_type.size() == 2); | |||
| std::shared_ptr<Tensor> nodes; | |||
| s = graph.GetAllNodes(meta_info.node_type[0], &nodes); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::vector<NodeIdType> node_list; | |||
| for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | |||
| node_list.push_back(*itr); | |||
| if (node_list.size() >= 10) { | |||
| break; | |||
| } | |||
| } | |||
| std::shared_ptr<Tensor> neg_neighbors; | |||
| s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neg_neighbors->shape().ToString() == "<10,4>"); | |||
| neg_neighbors.reset(); | |||
| s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); | |||
| neg_neighbors.reset(); | |||
| s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos); | |||
| } | |||
| @@ -12,6 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import random | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| @@ -77,8 +78,110 @@ def test_graphdata_getnodefeature_input_check(): | |||
| g.get_node_feature(input_list, [1, "a"]) | |||
| def test_graphdata_getsampledneighbors(): | |||
| g = ds.GraphData(DATASET_FILE, 1) | |||
| edges = g.get_all_edges(0) | |||
| nodes = g.get_nodes_from_edges(edges) | |||
| assert len(nodes) == 40 | |||
| neighbor = g.get_sampled_neighbors( | |||
| np.unique(nodes[0:21, 0]), [2, 3], [2, 1]) | |||
| assert neighbor.shape == (10, 9) | |||
| def test_graphdata_getnegsampledneighbors(): | |||
| g = ds.GraphData(DATASET_FILE, 2) | |||
| nodes = g.get_all_nodes(1) | |||
| assert len(nodes) == 10 | |||
| neighbor = g.get_neg_sampled_neighbors(nodes, 5, 2) | |||
| assert neighbor.shape == (10, 6) | |||
| def test_graphdata_graphinfo(): | |||
| g = ds.GraphData(DATASET_FILE, 2) | |||
| graph_info = g.graph_info() | |||
| assert graph_info['node_type'] == [1, 2] | |||
| assert graph_info['edge_type'] == [0] | |||
| assert graph_info['node_num'] == {1: 10, 2: 10} | |||
| assert graph_info['edge_num'] == {0: 40} | |||
| assert graph_info['node_feature_type'] == [1, 2, 3, 4] | |||
| assert graph_info['edge_feature_type'] == [] | |||
| class RandomBatchedSampler(ds.Sampler): | |||
| # RandomBatchedSampler generate random sequence without replacement in a batched manner | |||
| def __init__(self, index_range, num_edges_per_sample): | |||
| super().__init__() | |||
| self.index_range = index_range | |||
| self.num_edges_per_sample = num_edges_per_sample | |||
| def __iter__(self): | |||
| indices = [i+1 for i in range(self.index_range)] | |||
| # Reset random seed here if necessary | |||
| # random.seed(0) | |||
| random.shuffle(indices) | |||
| for i in range(0, self.index_range, self.num_edges_per_sample): | |||
| # Drop reminder | |||
| if i + self.num_edges_per_sample <= self.index_range: | |||
| yield indices[i: i + self.num_edges_per_sample] | |||
| class GNNGraphDataset(): | |||
| def __init__(self, g, batch_num): | |||
| self.g = g | |||
| self.batch_num = batch_num | |||
| def __len__(self): | |||
| # Total sample size of GNN dataset | |||
| # In this case, the size should be total_num_edges/num_edges_per_sample | |||
| return self.g.graph_info()['edge_num'][0] // self.batch_num | |||
| def __getitem__(self, index): | |||
| # index will be a list of indices yielded from RandomBatchedSampler | |||
| # Fetch edges/nodes/samples/features based on indices | |||
| nodes = self.g.get_nodes_from_edges(index.astype(np.int32)) | |||
| nodes = nodes[:, 0] | |||
| neg_nodes = self.g.get_neg_sampled_neighbors( | |||
| node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1) | |||
| nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[ | |||
| 2, 2], neighbor_types=[2, 1]) | |||
| neg_nodes_neighbors = self.g.get_sampled_neighbors( | |||
| node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2]) | |||
| nodes_neighbors_features = self.g.get_node_feature( | |||
| node_list=nodes_neighbors, feature_types=[2, 3]) | |||
| neg_neighbors_features = self.g.get_node_feature( | |||
| node_list=neg_nodes_neighbors, feature_types=[2, 3]) | |||
| return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1] | |||
| def test_graphdata_generatordataset(): | |||
| g = ds.GraphData(DATASET_FILE) | |||
| batch_num = 2 | |||
| edge_num = g.graph_info()['edge_num'][0] | |||
| out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"] | |||
| dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names, | |||
| sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4) | |||
| dataset = dataset.repeat(2) | |||
| itr = dataset.create_dict_iterator() | |||
| i = 0 | |||
| for data in itr: | |||
| assert data['neighbors'].shape == (2, 7) | |||
| assert data['neg_neighbors'].shape == (6, 7) | |||
| assert data['neighbors_features'].shape == (2, 7) | |||
| assert data['neg_neighbors_features'].shape == (6, 7) | |||
| i += 1 | |||
| assert i == 40 | |||
| if __name__ == '__main__': | |||
| test_graphdata_getfullneighbor() | |||
| logger.info('test_graphdata_getfullneighbor Ended.\n') | |||
| test_graphdata_getnodefeature_input_check() | |||
| logger.info('test_graphdata_getnodefeature_input_check Ended.\n') | |||
| test_graphdata_getsampledneighbors() | |||
| logger.info('test_graphdata_getsampledneighbors Ended.\n') | |||
| test_graphdata_getnegsampledneighbors() | |||
| logger.info('test_graphdata_getnegsampledneighbors Ended.\n') | |||
| test_graphdata_graphinfo() | |||
| logger.info('test_graphdata_graphinfo Ended.\n') | |||
| test_graphdata_generatordataset() | |||
| logger.info('test_graphdata_generatordataset Ended.\n') | |||