2. mod cora and citeseer conversion script — tags/v0.5.0-beta
| @@ -24,9 +24,6 @@ This example provides an efficient way to generate MindRecord. Users only need t | |||||
| 1. Download and prepare the Cora dataset as required. | 1. Download and prepare the Cora dataset as required. | ||||
| > [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora) | |||||
| 2. Edit write_cora.sh and modify the parameters | 2. Edit write_cora.sh and modify the parameters | ||||
| ``` | ``` | ||||
| --mindrecord_file: output MindRecord file. | --mindrecord_file: output MindRecord file. | ||||
| @@ -15,29 +15,26 @@ | |||||
| """ | """ | ||||
| User-defined API for MindRecord GNN writer. | User-defined API for MindRecord GNN writer. | ||||
| """ | """ | ||||
| import csv | |||||
| import os | import os | ||||
| import pickle as pkl | |||||
| import numpy as np | import numpy as np | ||||
| import scipy.sparse as sp | import scipy.sparse as sp | ||||
| # parse args from command line parameter 'graph_api_args' | # parse args from command line parameter 'graph_api_args' | ||||
| # args delimiter is ':' | # args delimiter is ':' | ||||
| args = os.environ['graph_api_args'].split(':') | args = os.environ['graph_api_args'].split(':') | ||||
| CITESEER_CONTENT_FILE = args[0] | |||||
| CITESEER_CITES_FILE = args[1] | |||||
| CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord" | |||||
| CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord" | |||||
| node_id_map = {} | |||||
| CITESEER_PATH = args[0] | |||||
| dataset_str = 'citeseer' | |||||
| # profile: (num_features, feature_data_types, feature_shapes) | # profile: (num_features, feature_data_types, feature_shapes) | ||||
| node_profile = (2, ["float32", "int64"], [[-1], [-1]]) | |||||
| node_profile = (2, ["float32", "int32"], [[-1], [-1]]) | |||||
| edge_profile = (0, [], []) | edge_profile = (0, [], []) | ||||
| node_ids = [] | |||||
| def _normalize_citeseer_features(features): | def _normalize_citeseer_features(features): | ||||
| features = np.array(features) | |||||
| row_sum = np.array(features.sum(1)) | row_sum = np.array(features.sum(1)) | ||||
| r_inv = np.power(row_sum * 1.0, -1).flatten() | r_inv = np.power(row_sum * 1.0, -1).flatten() | ||||
| r_inv[np.isinf(r_inv)] = 0. | r_inv[np.isinf(r_inv)] = 0. | ||||
| @@ -46,6 +43,14 @@ def _normalize_citeseer_features(features): | |||||
| return features | return features | ||||
| def _parse_index_file(filename): | |||||
| """Parse index file.""" | |||||
| index = [] | |||||
| for line in open(filename): | |||||
| index.append(int(line.strip())) | |||||
| return index | |||||
| def yield_nodes(task_id=0): | def yield_nodes(task_id=0): | ||||
| """ | """ | ||||
| Generate node data | Generate node data | ||||
| @@ -54,29 +59,46 @@ def yield_nodes(task_id=0): | |||||
| data (dict): data row which is dict. | data (dict): data row which is dict. | ||||
| """ | """ | ||||
| print("Node task is {}".format(task_id)) | print("Node task is {}".format(task_id)) | ||||
| label_types = {} | |||||
| label_size = 0 | |||||
| node_num = 0 | |||||
| with open(CITESEER_CONTENT_FILE) as content_file: | |||||
| content_reader = csv.reader(content_file, delimiter='\t') | |||||
| line_count = 0 | |||||
| for row in content_reader: | |||||
| if not row[-1] in label_types: | |||||
| label_types[row[-1]] = label_size | |||||
| label_size += 1 | |||||
| if not row[0] in node_id_map: | |||||
| node_id_map[row[0]] = node_num | |||||
| node_num += 1 | |||||
| raw_features = [[int(x) for x in row[1:-1]]] | |||||
| node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features), | |||||
| 'feature_2': [label_types[row[-1]]]} | |||||
| yield node | |||||
| line_count += 1 | |||||
| names = ['x', 'y', 'tx', 'ty', 'allx', 'ally'] | |||||
| objects = [] | |||||
| for name in names: | |||||
| with open("{}/ind.{}.{}".format(CITESEER_PATH, dataset_str, name), 'rb') as f: | |||||
| objects.append(pkl.load(f, encoding='latin1')) | |||||
| x, y, tx, ty, allx, ally = tuple(objects) | |||||
| test_idx_reorder = _parse_index_file( | |||||
| "{}/ind.{}.test.index".format(CITESEER_PATH, dataset_str)) | |||||
| test_idx_range = np.sort(test_idx_reorder) | |||||
| tx = _normalize_citeseer_features(tx) | |||||
| allx = _normalize_citeseer_features(allx) | |||||
| # Fix citeseer dataset (there are some isolated nodes in the graph) | |||||
| # Find isolated nodes, add them as zero-vecs into the right position | |||||
| test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1) | |||||
| tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1])) | |||||
| tx_extended[test_idx_range-min(test_idx_range), :] = tx | |||||
| tx = tx_extended | |||||
| ty_extended = np.zeros((len(test_idx_range_full), y.shape[1])) | |||||
| ty_extended[test_idx_range-min(test_idx_range), :] = ty | |||||
| ty = ty_extended | |||||
| features = sp.vstack((allx, tx)).tolil() | |||||
| features[test_idx_reorder, :] = features[test_idx_range, :] | |||||
| features = features.A | |||||
| labels = np.vstack((ally, ty)) | |||||
| labels[test_idx_reorder, :] = labels[test_idx_range, :] | |||||
| line_count = 0 | |||||
| for i, label in enumerate(labels): | |||||
| if not 1 in label.tolist(): | |||||
| continue | |||||
| node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(), | |||||
| 'feature_2': label.tolist().index(1)} | |||||
| line_count += 1 | |||||
| node_ids.append(i) | |||||
| yield node | |||||
| print('Processed {} lines for nodes.'.format(line_count)) | print('Processed {} lines for nodes.'.format(line_count)) | ||||
| # print('label types {}.'.format(label_types)) | |||||
| with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f: | |||||
| for k in label_types: | |||||
| print(k + ',' + str(label_types[k]), file=f) | |||||
| def yield_edges(task_id=0): | def yield_edges(task_id=0): | ||||
| @@ -87,23 +109,20 @@ def yield_edges(task_id=0): | |||||
| data (dict): data row which is dict. | data (dict): data row which is dict. | ||||
| """ | """ | ||||
| print("Edge task is {}".format(task_id)) | print("Edge task is {}".format(task_id)) | ||||
| # print(map_string_int) | |||||
| with open(CITESEER_CITES_FILE) as cites_file: | |||||
| cites_reader = csv.reader(cites_file, delimiter='\t') | |||||
| with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f: | |||||
| graph = pkl.load(f, encoding='latin1') | |||||
| line_count = 0 | line_count = 0 | ||||
| for row in cites_reader: | |||||
| if not row[0] in node_id_map: | |||||
| print('Source node {} does not exist.'.format(row[0])) | |||||
| continue | |||||
| if not row[1] in node_id_map: | |||||
| print('Destination node {} does not exist.'.format(row[1])) | |||||
| continue | |||||
| line_count += 1 | |||||
| edge = {'id': line_count, | |||||
| 'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} | |||||
| yield edge | |||||
| with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f: | |||||
| for k in node_id_map: | |||||
| print(k + ',' + str(node_id_map[k]), file=f) | |||||
| for i in graph: | |||||
| for dst_id in graph[i]: | |||||
| if not i in node_ids: | |||||
| print('Source node {} does not exist.'.format(i)) | |||||
| continue | |||||
| if not dst_id in node_ids: | |||||
| print('Destination node {} does not exist.'.format( | |||||
| dst_id)) | |||||
| continue | |||||
| edge = {'id': line_count, | |||||
| 'src_id': i, 'dst_id': dst_id, 'type': 0} | |||||
| line_count += 1 | |||||
| yield edge | |||||
| print('Processed {} lines for edges.'.format(line_count)) | print('Processed {} lines for edges.'.format(line_count)) | ||||
| @@ -15,29 +15,24 @@ | |||||
| """ | """ | ||||
| User-defined API for MindRecord GNN writer. | User-defined API for MindRecord GNN writer. | ||||
| """ | """ | ||||
| import csv | |||||
| import os | import os | ||||
| import pickle as pkl | |||||
| import numpy as np | import numpy as np | ||||
| import scipy.sparse as sp | import scipy.sparse as sp | ||||
| # parse args from command line parameter 'graph_api_args' | # parse args from command line parameter 'graph_api_args' | ||||
| # args delimiter is ':' | # args delimiter is ':' | ||||
| args = os.environ['graph_api_args'].split(':') | args = os.environ['graph_api_args'].split(':') | ||||
| CORA_CONTENT_FILE = args[0] | |||||
| CORA_CITES_FILE = args[1] | |||||
| CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord" | |||||
| CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord" | |||||
| node_id_map = {} | |||||
| CORA_PATH = args[0] | |||||
| dataset_str = 'cora' | |||||
| # profile: (num_features, feature_data_types, feature_shapes) | # profile: (num_features, feature_data_types, feature_shapes) | ||||
| node_profile = (2, ["float32", "int64"], [[-1], [-1]]) | |||||
| node_profile = (2, ["float32", "int32"], [[-1], [-1]]) | |||||
| edge_profile = (0, [], []) | edge_profile = (0, [], []) | ||||
| def _normalize_cora_features(features): | def _normalize_cora_features(features): | ||||
| features = np.array(features) | |||||
| row_sum = np.array(features.sum(1)) | row_sum = np.array(features.sum(1)) | ||||
| r_inv = np.power(row_sum * 1.0, -1).flatten() | r_inv = np.power(row_sum * 1.0, -1).flatten() | ||||
| r_inv[np.isinf(r_inv)] = 0. | r_inv[np.isinf(r_inv)] = 0. | ||||
| @@ -46,6 +41,14 @@ def _normalize_cora_features(features): | |||||
| return features | return features | ||||
| def _parse_index_file(filename): | |||||
| """Parse index file.""" | |||||
| index = [] | |||||
| for line in open(filename): | |||||
| index.append(int(line.strip())) | |||||
| return index | |||||
| def yield_nodes(task_id=0): | def yield_nodes(task_id=0): | ||||
| """ | """ | ||||
| Generate node data | Generate node data | ||||
| @@ -54,32 +57,32 @@ def yield_nodes(task_id=0): | |||||
| data (dict): data row which is dict. | data (dict): data row which is dict. | ||||
| """ | """ | ||||
| print("Node task is {}".format(task_id)) | print("Node task is {}".format(task_id)) | ||||
| label_types = {} | |||||
| label_size = 0 | |||||
| node_num = 0 | |||||
| with open(CORA_CONTENT_FILE) as content_file: | |||||
| content_reader = csv.reader(content_file, delimiter=',') | |||||
| line_count = 0 | |||||
| for row in content_reader: | |||||
| if line_count == 0: | |||||
| line_count += 1 | |||||
| continue | |||||
| if not row[0] in node_id_map: | |||||
| node_id_map[row[0]] = node_num | |||||
| node_num += 1 | |||||
| if not row[-1] in label_types: | |||||
| label_types[row[-1]] = label_size | |||||
| label_size += 1 | |||||
| raw_features = [[int(x) for x in row[1:-1]]] | |||||
| node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features), | |||||
| 'feature_2': [label_types[row[-1]]]} | |||||
| yield node | |||||
| line_count += 1 | |||||
| names = ['tx', 'ty', 'allx', 'ally'] | |||||
| objects = [] | |||||
| for name in names: | |||||
| with open("{}/ind.{}.{}".format(CORA_PATH, dataset_str, name), 'rb') as f: | |||||
| objects.append(pkl.load(f, encoding='latin1')) | |||||
| tx, ty, allx, ally = tuple(objects) | |||||
| test_idx_reorder = _parse_index_file( | |||||
| "{}/ind.{}.test.index".format(CORA_PATH, dataset_str)) | |||||
| test_idx_range = np.sort(test_idx_reorder) | |||||
| features = sp.vstack((allx, tx)).tolil() | |||||
| features[test_idx_reorder, :] = features[test_idx_range, :] | |||||
| features = _normalize_cora_features(features) | |||||
| features = features.A | |||||
| labels = np.vstack((ally, ty)) | |||||
| labels[test_idx_reorder, :] = labels[test_idx_range, :] | |||||
| line_count = 0 | |||||
| for i, label in enumerate(labels): | |||||
| node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(), | |||||
| 'feature_2': label.tolist().index(1)} | |||||
| line_count += 1 | |||||
| yield node | |||||
| print('Processed {} lines for nodes.'.format(line_count)) | print('Processed {} lines for nodes.'.format(line_count)) | ||||
| print('label types {}.'.format(label_types)) | |||||
| with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f: | |||||
| for k in label_types: | |||||
| print(k + ',' + str(label_types[k]), file=f) | |||||
| def yield_edges(task_id=0): | def yield_edges(task_id=0): | ||||
| @@ -90,24 +93,13 @@ def yield_edges(task_id=0): | |||||
| data (dict): data row which is dict. | data (dict): data row which is dict. | ||||
| """ | """ | ||||
| print("Edge task is {}".format(task_id)) | print("Edge task is {}".format(task_id)) | ||||
| with open(CORA_CITES_FILE) as cites_file: | |||||
| cites_reader = csv.reader(cites_file, delimiter=',') | |||||
| with open("{}/ind.{}.graph".format(CORA_PATH, dataset_str), 'rb') as f: | |||||
| graph = pkl.load(f, encoding='latin1') | |||||
| line_count = 0 | line_count = 0 | ||||
| for row in cites_reader: | |||||
| if line_count == 0: | |||||
| for i in graph: | |||||
| for dst_id in graph[i]: | |||||
| edge = {'id': line_count, | |||||
| 'src_id': i, 'dst_id': dst_id, 'type': 0} | |||||
| line_count += 1 | line_count += 1 | ||||
| continue | |||||
| if not row[0] in node_id_map: | |||||
| print('Source node {} does not exist.'.format(row[0])) | |||||
| continue | |||||
| if not row[1] in node_id_map: | |||||
| print('Destination node {} does not exist.'.format(row[1])) | |||||
| continue | |||||
| edge = {'id': line_count, | |||||
| 'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} | |||||
| yield edge | |||||
| line_count += 1 | |||||
| yield edge | |||||
| print('Processed {} lines for edges.'.format(line_count)) | print('Processed {} lines for edges.'.format(line_count)) | ||||
| with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f: | |||||
| for k in node_id_map: | |||||
| print(k + ',' + str(node_id_map[k]), file=f) | |||||
| @@ -9,4 +9,4 @@ python writer.py --mindrecord_script citeseer \ | |||||
| --mindrecord_partitions 1 \ | --mindrecord_partitions 1 \ | ||||
| --mindrecord_header_size_by_bit 18 \ | --mindrecord_header_size_by_bit 18 \ | ||||
| --mindrecord_page_size_by_bit 20 \ | --mindrecord_page_size_by_bit 20 \ | ||||
| --graph_api_args "$SRC_PATH/citeseer.content:$SRC_PATH/citeseer.cites" | |||||
| --graph_api_args "$SRC_PATH" | |||||
| @@ -9,4 +9,4 @@ python writer.py --mindrecord_script cora \ | |||||
| --mindrecord_partitions 1 \ | --mindrecord_partitions 1 \ | ||||
| --mindrecord_header_size_by_bit 18 \ | --mindrecord_header_size_by_bit 18 \ | ||||
| --mindrecord_page_size_by_bit 20 \ | --mindrecord_page_size_by_bit 20 \ | ||||
| --graph_api_args "$SRC_PATH/cora_content.csv:$SRC_PATH/cora_cites.csv" | |||||
| --graph_api_args "$SRC_PATH" | |||||
| @@ -527,10 +527,22 @@ void bindGraphData(py::module *m) { | |||||
| THROW_IF_ERROR(g_out->Init()); | THROW_IF_ERROR(g_out->Init()); | ||||
| return g_out; | return g_out; | ||||
| })) | })) | ||||
| .def("get_nodes", | |||||
| [](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) { | |||||
| .def("get_all_nodes", | |||||
| [](gnn::Graph &g, gnn::NodeType node_type) { | |||||
| std::shared_ptr<Tensor> out; | std::shared_ptr<Tensor> out; | ||||
| THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out)); | |||||
| THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); | |||||
| return out; | |||||
| }) | |||||
| .def("get_all_edges", | |||||
| [](gnn::Graph &g, gnn::EdgeType edge_type) { | |||||
| std::shared_ptr<Tensor> out; | |||||
| THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); | |||||
| return out; | |||||
| }) | |||||
| .def("get_nodes_from_edges", | |||||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) { | |||||
| std::shared_ptr<Tensor> out; | |||||
| THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); | |||||
| return out; | return out; | ||||
| }) | }) | ||||
| .def("get_all_neighbors", | .def("get_all_neighbors", | ||||
| @@ -539,12 +551,31 @@ void bindGraphData(py::module *m) { | |||||
| THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); | THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); | ||||
| return out; | return out; | ||||
| }) | }) | ||||
| .def("get_sampled_neighbors", | |||||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums, | |||||
| std::vector<gnn::NodeType> neighbor_types) { | |||||
| std::shared_ptr<Tensor> out; | |||||
| THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); | |||||
| return out; | |||||
| }) | |||||
| .def("get_neg_sampled_neighbors", | |||||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num, | |||||
| gnn::NodeType neg_neighbor_type) { | |||||
| std::shared_ptr<Tensor> out; | |||||
| THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); | |||||
| return out; | |||||
| }) | |||||
| .def("get_node_feature", | .def("get_node_feature", | ||||
| [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) { | [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) { | ||||
| TensorRow out; | TensorRow out; | ||||
| THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); | THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); | ||||
| return out; | return out; | ||||
| }); | |||||
| }) | |||||
| .def("graph_info", [](gnn::Graph &g) { | |||||
| py::dict out; | |||||
| THROW_IF_ERROR(g.GraphInfo(&out)); | |||||
| return out; | |||||
| }); | |||||
| } | } | ||||
| // This is where we externalize the C logic as python modules | // This is where we externalize the C logic as python modules | ||||
| @@ -17,29 +17,30 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <functional> | #include <functional> | ||||
| #include <iterator> | |||||
| #include <numeric> | #include <numeric> | ||||
| #include <utility> | #include <utility> | ||||
| #include "dataset/core/tensor_shape.h" | #include "dataset/core/tensor_shape.h" | ||||
| #include "dataset/util/random.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace gnn { | namespace gnn { | ||||
| Graph::Graph(std::string dataset_file, int32_t num_workers) : dataset_file_(dataset_file), num_workers_(num_workers) { | |||||
| Graph::Graph(std::string dataset_file, int32_t num_workers) | |||||
| : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) { | |||||
| rnd_.seed(GetSeed()); | |||||
| MS_LOG(INFO) << "num_workers:" << num_workers; | MS_LOG(INFO) << "num_workers:" << num_workers; | ||||
| } | } | ||||
| Status Graph::GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out) { | |||||
| Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { | |||||
| auto itr = node_type_map_.find(node_type); | auto itr = node_type_map_.find(node_type); | ||||
| if (itr == node_type_map_.end()) { | if (itr == node_type_map_.end()) { | ||||
| std::string err_msg = "Invalid node type:" + std::to_string(node_type); | std::string err_msg = "Invalid node type:" + std::to_string(node_type); | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } else { | } else { | ||||
| if (node_num == -1) { | |||||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out)); | |||||
| } else { | |||||
| } | |||||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out)); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -59,9 +60,9 @@ Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, Data | |||||
| RETURN_IF_NOT_OK(Tensor::CreateTensor( | RETURN_IF_NOT_OK(Tensor::CreateTensor( | ||||
| &tensor, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(m), static_cast<dsize_t>(n)}), type, nullptr)); | &tensor, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(m), static_cast<dsize_t>(n)}), type, nullptr)); | ||||
| T *ptr = reinterpret_cast<T *>(tensor->GetMutableBuffer()); | T *ptr = reinterpret_cast<T *>(tensor->GetMutableBuffer()); | ||||
| for (auto id_m : data) { | |||||
| for (const auto &id_m : data) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); | CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); | ||||
| for (auto id_n : id_m) { | |||||
| for (const auto &id_n : id_m) { | |||||
| *ptr = id_n; | *ptr = id_n; | ||||
| ptr++; | ptr++; | ||||
| } | } | ||||
| @@ -89,7 +90,38 @@ Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_siz | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out) { return Status::OK(); } | |||||
| Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { | |||||
| auto itr = edge_type_map_.find(edge_type); | |||||
| if (itr == edge_type_map_.end()) { | |||||
| std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } else { | |||||
| RETURN_IF_NOT_OK(CreateTensorByVector<EdgeIdType>({itr->second}, DataType(DataType::DE_INT32), out)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) { | |||||
| if (edge_list.empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); | |||||
| } | |||||
| std::vector<std::vector<NodeIdType>> node_list; | |||||
| node_list.reserve(edge_list.size()); | |||||
| for (const auto &edge_id : edge_list) { | |||||
| auto itr = edge_id_map_.find(edge_id); | |||||
| if (itr == edge_id_map_.end()) { | |||||
| std::string err_msg = "Invalid edge id:" + std::to_string(edge_id); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } else { | |||||
| std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> nodes; | |||||
| RETURN_IF_NOT_OK(itr->second->GetNode(&nodes)); | |||||
| node_list.push_back({nodes.first->id(), nodes.second->id()}); | |||||
| } | |||||
| } | |||||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(node_list, DataType(DataType::DE_INT32), out)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | ||||
| std::shared_ptr<Tensor> *out) { | std::shared_ptr<Tensor> *out) { | ||||
| @@ -105,14 +137,10 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType | |||||
| size_t max_neighbor_num = 0; | size_t max_neighbor_num = 0; | ||||
| neighbors.resize(node_list.size()); | neighbors.resize(node_list.size()); | ||||
| for (size_t i = 0; i < node_list.size(); ++i) { | for (size_t i = 0; i < node_list.size(); ++i) { | ||||
| auto itr = node_id_map_.find(node_list[i]); | |||||
| if (itr != node_id_map_.end()) { | |||||
| RETURN_IF_NOT_OK(itr->second->GetNeighbors(neighbor_type, -1, &neighbors[i])); | |||||
| max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); | |||||
| } else { | |||||
| std::string err_msg = "Invalid node id:" + std::to_string(node_list[i]); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| std::shared_ptr<Node> node; | |||||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node)); | |||||
| RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i])); | |||||
| max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); | |||||
| } | } | ||||
| RETURN_IF_NOT_OK(ComplementVector<NodeIdType>(&neighbors, max_neighbor_num, kDefaultNodeId)); | RETURN_IF_NOT_OK(ComplementVector<NodeIdType>(&neighbors, max_neighbor_num, kDefaultNodeId)); | ||||
| @@ -121,13 +149,94 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||||
| Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||||
| const std::vector<NodeIdType> &neighbor_nums, | |||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), | |||||
| "The sizes of neighbor_nums and neighbor_types are inconsistent."); | |||||
| std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size()); | |||||
| for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | |||||
| neighbors_vec[node_idx].emplace_back(node_list[node_idx]); | |||||
| std::vector<NodeIdType> input_list = {node_list[node_idx]}; | |||||
| for (size_t i = 0; i < neighbor_nums.size(); ++i) { | |||||
| std::vector<NodeIdType> neighbors; | |||||
| neighbors.reserve(input_list.size() * neighbor_nums[i]); | |||||
| for (const auto &node_id : input_list) { | |||||
| if (node_id == kDefaultNodeId) { | |||||
| for (int32_t j = 0; j < neighbor_nums[i]; ++j) { | |||||
| neighbors.emplace_back(kDefaultNodeId); | |||||
| } | |||||
| } else { | |||||
| std::shared_ptr<Node> node; | |||||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); | |||||
| std::vector<NodeIdType> out; | |||||
| RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); | |||||
| neighbors.insert(neighbors.end(), out.begin(), out.end()); | |||||
| } | |||||
| } | |||||
| neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end()); | |||||
| input_list = std::move(neighbors); | |||||
| } | |||||
| } | |||||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||||
| Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::unordered_set<NodeIdType> &exclude_data, | |||||
| int32_t samples_num, std::vector<NodeIdType> *out_samples) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); | |||||
| std::vector<NodeIdType> shuffled_id(data.size()); | |||||
| std::iota(shuffled_id.begin(), shuffled_id.end(), 0); | |||||
| std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); | |||||
| for (const auto &index : shuffled_id) { | |||||
| if (exclude_data.find(data[index]) != exclude_data.end()) { | |||||
| continue; | |||||
| } | |||||
| out_samples->emplace_back(data[index]); | |||||
| if (out_samples->size() >= samples_num) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||||
| std::vector<std::vector<NodeIdType>> neighbors_vec; | |||||
| neighbors_vec.resize(node_list.size()); | |||||
| for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | |||||
| std::shared_ptr<Node> node; | |||||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); | |||||
| std::vector<NodeIdType> neighbors; | |||||
| RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); | |||||
| std::unordered_set<NodeIdType> exclude_node; | |||||
| std::transform(neighbors.begin(), neighbors.end(), | |||||
| std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_node, exclude_node.begin()), | |||||
| [](const NodeIdType node) { return node; }); | |||||
| auto itr = node_type_map_.find(neg_neighbor_type); | |||||
| if (itr == node_type_map_.end()) { | |||||
| std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } else { | |||||
| neighbors_vec[node_idx].emplace_back(node->id()); | |||||
| if (itr->second.size() > exclude_node.size()) { | |||||
| while (neighbors_vec[node_idx].size() < samples_num + 1) { | |||||
| RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node, samples_num - neighbors_vec[node_idx].size(), | |||||
| &neighbors_vec[node_idx])); | |||||
| } | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() | |||||
| << " neg_neighbor_type:" << neg_neighbor_type; | |||||
| // If there are no negative neighbors, they are filled with kDefaultNodeId | |||||
| for (int32_t i = 0; i < samples_num; ++i) { | |||||
| neighbors_vec[node_idx].emplace_back(kDefaultNodeId); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -154,7 +263,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve | |||||
| } | } | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty"); | CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty"); | ||||
| TensorRow tensors; | TensorRow tensors; | ||||
| for (auto f_type : feature_types) { | |||||
| for (const auto &f_type : feature_types) { | |||||
| std::shared_ptr<Feature> default_feature; | std::shared_ptr<Feature> default_feature; | ||||
| // If no feature can be obtained, fill in the default value | // If no feature can be obtained, fill in the default value | ||||
| RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature)); | RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature)); | ||||
| @@ -169,18 +278,14 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve | |||||
| dsize_t index = 0; | dsize_t index = 0; | ||||
| for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) { | for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) { | ||||
| auto itr = node_id_map_.find(*node_itr); | |||||
| std::shared_ptr<Feature> feature; | std::shared_ptr<Feature> feature; | ||||
| if (itr != node_id_map_.end()) { | |||||
| if (!itr->second->GetFeatures(f_type, &feature).IsOk()) { | |||||
| feature = default_feature; | |||||
| } | |||||
| if (*node_itr == kDefaultNodeId) { | |||||
| feature = default_feature; | |||||
| } else { | } else { | ||||
| if (*node_itr == kDefaultNodeId) { | |||||
| std::shared_ptr<Node> node; | |||||
| RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); | |||||
| if (!node->GetFeatures(f_type, &feature).IsOk()) { | |||||
| feature = default_feature; | feature = default_feature; | ||||
| } else { | |||||
| std::string err_msg = "Invalid node id:" + std::to_string(*node_itr); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); | RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); | ||||
| @@ -209,35 +314,54 @@ Status Graph::Init() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info) { | |||||
| node_info->reserve(node_type_map_.size()); | |||||
| for (auto node : node_type_map_) { | |||||
| NodeMetaInfo n_info; | |||||
| n_info.type = node.first; | |||||
| n_info.num = node.second.size(); | |||||
| auto itr = node_feature_map_.find(node.first); | |||||
| if (itr != node_feature_map_.end()) { | |||||
| for (auto f_type : itr->second) { | |||||
| n_info.feature_type.push_back(f_type); | |||||
| } | |||||
| std::sort(n_info.feature_type.begin(), n_info.feature_type.end()); | |||||
| Status Graph::GetMetaInfo(MetaInfo *meta_info) { | |||||
| meta_info->node_type.resize(node_type_map_.size()); | |||||
| std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), | |||||
| [](auto itr) { return itr.first; }); | |||||
| std::sort(meta_info->node_type.begin(), meta_info->node_type.end()); | |||||
| meta_info->edge_type.resize(edge_type_map_.size()); | |||||
| std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(), | |||||
| [](auto itr) { return itr.first; }); | |||||
| std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end()); | |||||
| for (const auto &node : node_type_map_) { | |||||
| meta_info->node_num[node.first] = node.second.size(); | |||||
| } | |||||
| for (const auto &edge : edge_type_map_) { | |||||
| meta_info->edge_num[edge.first] = edge.second.size(); | |||||
| } | |||||
| for (const auto &node_feature : node_feature_map_) { | |||||
| for (auto type : node_feature.second) { | |||||
| meta_info->node_feature_type.emplace_back(type); | |||||
| } | } | ||||
| node_info->push_back(n_info); | |||||
| } | |||||
| edge_info->reserve(edge_type_map_.size()); | |||||
| for (auto edge : edge_type_map_) { | |||||
| EdgeMetaInfo e_info; | |||||
| e_info.type = edge.first; | |||||
| e_info.num = edge.second.size(); | |||||
| auto itr = edge_feature_map_.find(edge.first); | |||||
| if (itr != edge_feature_map_.end()) { | |||||
| for (auto f_type : itr->second) { | |||||
| e_info.feature_type.push_back(f_type); | |||||
| } | |||||
| } | |||||
| std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); | |||||
| auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); | |||||
| meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end()); | |||||
| for (const auto &edge_feature : edge_feature_map_) { | |||||
| for (const auto &type : edge_feature.second) { | |||||
| meta_info->edge_feature_type.emplace_back(type); | |||||
| } | } | ||||
| edge_info->push_back(e_info); | |||||
| } | } | ||||
| std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); | |||||
| auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); | |||||
| meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end()); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status Graph::GraphInfo(py::dict *out) { | |||||
| MetaInfo meta_info; | |||||
| RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); | |||||
| (*out)["node_type"] = py::cast(meta_info.node_type); | |||||
| (*out)["edge_type"] = py::cast(meta_info.edge_type); | |||||
| (*out)["node_num"] = py::cast(meta_info.node_num); | |||||
| (*out)["edge_num"] = py::cast(meta_info.edge_num); | |||||
| (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type); | |||||
| (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -250,6 +374,18 @@ Status Graph::LoadNodeAndEdge() { | |||||
| &node_feature_map_, &edge_feature_map_, &default_feature_map_)); | &node_feature_map_, &edge_feature_map_, &default_feature_map_)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||||
| auto itr = node_id_map_.find(id); | |||||
| if (itr == node_id_map_.end()) { | |||||
| std::string err_msg = "Invalid node id:" + std::to_string(id); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } else { | |||||
| *node = itr->second; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace gnn | } // namespace gnn | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <map> | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -33,24 +34,13 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| namespace gnn { | namespace gnn { | ||||
| struct NodeMetaInfo { | |||||
| NodeType type; | |||||
| NodeIdType num; | |||||
| std::vector<FeatureType> feature_type; | |||||
| NodeMetaInfo() { | |||||
| type = 0; | |||||
| num = 0; | |||||
| } | |||||
| }; | |||||
| struct EdgeMetaInfo { | |||||
| EdgeType type; | |||||
| EdgeIdType num; | |||||
| std::vector<FeatureType> feature_type; | |||||
| EdgeMetaInfo() { | |||||
| type = 0; | |||||
| num = 0; | |||||
| } | |||||
| struct MetaInfo { | |||||
| std::vector<NodeType> node_type; | |||||
| std::vector<EdgeType> edge_type; | |||||
| std::map<NodeType, NodeIdType> node_num; | |||||
| std::map<EdgeType, EdgeIdType> edge_num; | |||||
| std::vector<FeatureType> node_feature_type; | |||||
| std::vector<FeatureType> edge_feature_type; | |||||
| }; | }; | ||||
| class Graph { | class Graph { | ||||
| @@ -62,19 +52,23 @@ class Graph { | |||||
| ~Graph() = default; | ~Graph() = default; | ||||
| // Get the nodes from the graph. | |||||
| // Get all nodes from the graph. | |||||
| // @param NodeType node_type - type of node | // @param NodeType node_type - type of node | ||||
| // @param NodeIdType node_num - Number of nodes to be acquired, if -1 means all nodes are acquired | |||||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id | // @param std::shared_ptr<Tensor> *out - Returned nodes id | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out); | |||||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out); | |||||
| // Get the edges from the graph. | |||||
| // Get all edges from the graph. | |||||
| // @param EdgeType edge_type - type of edge | // @param EdgeType edge_type - type of edge | ||||
| // @param NodeIdType edge_num - Number of edges to be acquired, if -1 means all edges are acquired | |||||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | // @param std::shared_ptr<Tensor> *out - Returned edge ids | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out); | |||||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out); | |||||
| // Get the node id from the edge. | |||||
| // @param std::vector<EdgeIdType> edge_list - List of edges | |||||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||||
| // @return Status - The error code return | |||||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out); | |||||
| // All neighbors of the acquisition node. | // All neighbors of the acquisition node. | ||||
| // @param std::vector<NodeIdType> node_list - List of nodes | // @param std::vector<NodeIdType> node_list - List of nodes | ||||
| @@ -86,10 +80,24 @@ class Graph { | |||||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | ||||
| std::shared_ptr<Tensor> *out); | std::shared_ptr<Tensor> *out); | ||||
| Status GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out); | |||||
| Status GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out); | |||||
| // Get sampled neighbors. | |||||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||||
| // @return Status - The error code return | |||||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out); | |||||
| // Get negative sampled neighbors. | |||||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||||
| // @param NodeIdType samples_num - Number of neighbors sampled | |||||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | |||||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | |||||
| // @return Status - The error code return | |||||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out); | |||||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q, | Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q, | ||||
| NodeIdType default_node, std::shared_ptr<Tensor> *out); | NodeIdType default_node, std::shared_ptr<Tensor> *out); | ||||
| @@ -112,10 +120,12 @@ class Graph { | |||||
| TensorRow *out); | TensorRow *out); | ||||
| // Get meta information of graph | // Get meta information of graph | ||||
| // @param std::vector<NodeMetaInfo> *node_info - Returned meta information of node | |||||
| // @param std::vector<NodeMetaInfo> *node_info - Returned meta information of edge | |||||
| // @param MetaInfo *meta_info - Returned meta information | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info); | |||||
| Status GetMetaInfo(MetaInfo *meta_info); | |||||
| // Return meta information to python layer | |||||
| Status GraphInfo(py::dict *out); | |||||
| Status Init(); | Status Init(); | ||||
| @@ -146,8 +156,24 @@ class Graph { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | ||||
| // Find node object using node id | |||||
| // @param NodeIdType id - Id of the node to look up | |||||
| // @param std::shared_ptr<Node> *node - Returned node object | |||||
| // @return Status - The error code return | |||||
| Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node); | |||||
| // Negative sampling | |||||
| // @param std::vector<NodeIdType> &input_data - The data set to be sampled | |||||
| // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded | |||||
| // @param int32_t samples_num - Number of samples to be acquired | |||||
| // @param std::vector<NodeIdType> *out_samples - Sampling results returned | |||||
| // @return Status - The error code return | |||||
| Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data, | |||||
| int32_t samples_num, std::vector<NodeIdType> *out_samples); | |||||
| std::string dataset_file_; | std::string dataset_file_; | ||||
| int32_t num_workers_; // The number of worker threads | int32_t num_workers_; // The number of worker threads | ||||
| std::mt19937 rnd_; | |||||
| std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_; | std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_; | ||||
| std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_; | std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_; | ||||
| @@ -20,12 +20,13 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "dataset/engine/gnn/edge.h" | #include "dataset/engine/gnn/edge.h" | ||||
| #include "dataset/util/random.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace gnn { | namespace gnn { | ||||
| LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type) {} | |||||
| LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } | |||||
| Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | ||||
| auto itr = features_.find(feature_type); | auto itr = features_.find(feature_type); | ||||
| @@ -38,21 +39,49 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> | |||||
| } | } | ||||
| } | } | ||||
| Status LocalNode::GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) { | |||||
| Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) { | |||||
| std::vector<NodeIdType> neighbors; | std::vector<NodeIdType> neighbors; | ||||
| auto itr = neighbor_nodes_.find(neighbor_type); | auto itr = neighbor_nodes_.find(neighbor_type); | ||||
| if (itr != neighbor_nodes_.end()) { | if (itr != neighbor_nodes_.end()) { | ||||
| if (samples_num == -1) { | |||||
| // Return all neighbors | |||||
| neighbors.resize(itr->second.size() + 1); | |||||
| neighbors[0] = id_; | |||||
| std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, | |||||
| [](const std::shared_ptr<Node> node) { return node->id(); }); | |||||
| } else { | |||||
| } | |||||
| neighbors.resize(itr->second.size() + 1); | |||||
| neighbors[0] = id_; | |||||
| std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, | |||||
| [](const std::shared_ptr<Node> node) { return node->id(); }); | |||||
| } else { | } else { | ||||
| neighbors.push_back(id_); | |||||
| MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; | MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; | ||||
| neighbors.emplace_back(id_); | |||||
| } | |||||
| *out_neighbors = std::move(neighbors); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, | |||||
| std::vector<NodeIdType> *out) { | |||||
| std::vector<NodeIdType> shuffled_id(neighbors.size()); | |||||
| std::iota(shuffled_id.begin(), shuffled_id.end(), 0); | |||||
| std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); | |||||
| int32_t num = std::min(samples_num, static_cast<int32_t>(neighbors.size())); | |||||
| for (int32_t i = 0; i < num; ++i) { | |||||
| out->emplace_back(neighbors[shuffled_id[i]]->id()); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||||
| std::vector<NodeIdType> *out_neighbors) { | |||||
| std::vector<NodeIdType> neighbors; | |||||
| neighbors.reserve(samples_num); | |||||
| auto itr = neighbor_nodes_.find(neighbor_type); | |||||
| if (itr != neighbor_nodes_.end()) { | |||||
| while (neighbors.size() < samples_num) { | |||||
| RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); | |||||
| } | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; | |||||
| // If there are no neighbors, they are filled with kDefaultNodeId | |||||
| for (int32_t i = 0; i < samples_num; ++i) { | |||||
| neighbors.emplace_back(kDefaultNodeId); | |||||
| } | |||||
| } | } | ||||
| *out_neighbors = std::move(neighbors); | *out_neighbors = std::move(neighbors); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -43,12 +43,19 @@ class LocalNode : public Node { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; | Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override; | ||||
| // Get the neighbors of a node | |||||
| // Get all the neighbors of a node | |||||
| // @param NodeType neighbor_type - type of neighbor | // @param NodeType neighbor_type - type of neighbor | ||||
| // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired | |||||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) override; | |||||
| Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) override; | |||||
| // Get the sampled neighbors of a node | |||||
| // @param NodeType neighbor_type - type of neighbor | |||||
| // @param int32_t samples_num - Number of neighbors to be acquired | |||||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||||
| // @return Status - The error code return | |||||
| Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||||
| std::vector<NodeIdType> *out_neighbors) override; | |||||
| // Add neighbor of node | // Add neighbor of node | ||||
| // @param std::shared_ptr<Node> node - | // @param std::shared_ptr<Node> node - | ||||
| @@ -61,6 +68,10 @@ class LocalNode : public Node { | |||||
| Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | ||||
| private: | private: | ||||
| Status GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, | |||||
| std::vector<NodeIdType> *out); | |||||
| std::mt19937 rnd_; | |||||
| std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_; | std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_; | ||||
| std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_; | std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_; | ||||
| }; | }; | ||||
| @@ -52,12 +52,19 @@ class Node { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; | virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0; | ||||
| // Get the neighbors of a node | |||||
| // Get all the neighbors of a node | |||||
| // @param NodeType neighbor_type - type of neighbor | // @param NodeType neighbor_type - type of neighbor | ||||
| // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired | |||||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| virtual Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) = 0; | |||||
| virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) = 0; | |||||
| // Get the sampled neighbors of a node | |||||
| // @param NodeType neighbor_type - type of neighbor | |||||
| // @param int32_t samples_num - Number of neighbors to be acquired | |||||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||||
| // @return Status - The error code return | |||||
| virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||||
| std::vector<NodeIdType> *out_neighbors) = 0; | |||||
| // Add neighbor of node | // Add neighbor of node | ||||
| // @param std::shared_ptr<Node> node - | // @param std::shared_ptr<Node> node - | ||||
| @@ -20,8 +20,9 @@ import numpy as np | |||||
| from mindspore._c_dataengine import Graph | from mindspore._c_dataengine import Graph | ||||
| from mindspore._c_dataengine import Tensor | from mindspore._c_dataengine import Tensor | ||||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \ | |||||
| check_gnn_get_node_feature | |||||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | |||||
| check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ | |||||
| check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature | |||||
| class GraphData: | class GraphData: | ||||
| @@ -60,7 +61,44 @@ class GraphData: | |||||
| Raises: | Raises: | ||||
| TypeError: If `node_type` is not integer. | TypeError: If `node_type` is not integer. | ||||
| """ | """ | ||||
| return self._graph.get_nodes(node_type, -1).as_array() | |||||
| return self._graph.get_all_nodes(node_type).as_array() | |||||
| @check_gnn_get_all_edges | |||||
| def get_all_edges(self, edge_type): | |||||
| """ | |||||
| Get all edges in the graph. | |||||
| Args: | |||||
| edge_type (int): Specify the type of edge. | |||||
| Returns: | |||||
| numpy.ndarray: array of edges. | |||||
| Examples: | |||||
| >>> import mindspore.dataset as ds | |||||
| >>> data_graph = ds.GraphData('dataset_file', 2) | |||||
| >>> edges = data_graph.get_all_edges(0) | |||||
| Raises: | |||||
| TypeError: If `edge_type` is not integer. | |||||
| """ | |||||
| return self._graph.get_all_edges(edge_type).as_array() | |||||
| @check_gnn_get_nodes_from_edges | |||||
| def get_nodes_from_edges(self, edge_list): | |||||
| """ | |||||
| Get nodes from the edges. | |||||
| Args: | |||||
| edge_list (list or numpy.ndarray): The given list of edges. | |||||
| Returns: | |||||
| numpy.ndarray: array of nodes. | |||||
| Raises: | |||||
| TypeError: If `edge_list` is not list or ndarray. | |||||
| """ | |||||
| return self._graph.get_nodes_from_edges(edge_list).as_array() | |||||
| @check_gnn_get_all_neighbors | @check_gnn_get_all_neighbors | ||||
| def get_all_neighbors(self, node_list, neighbor_type): | def get_all_neighbors(self, node_list, neighbor_type): | ||||
| @@ -86,6 +124,58 @@ class GraphData: | |||||
| """ | """ | ||||
| return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() | return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() | ||||
| @check_gnn_get_sampled_neighbors | |||||
| def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): | |||||
| """ | |||||
| Get sampled neighbor information. At most 6-hop sampling is supported. | |||||
| Args: | |||||
| node_list (list or numpy.ndarray): The given list of nodes. | |||||
| neighbor_nums (list or numpy.ndarray): Number of neighbors sampled per hop. | |||||
| neighbor_types (list or numpy.ndarray): Neighbor type sampled per hop. | |||||
| Returns: | |||||
| numpy.ndarray: array of nodes. | |||||
| Examples: | |||||
| >>> import mindspore.dataset as ds | |||||
| >>> data_graph = ds.GraphData('dataset_file', 2) | |||||
| >>> nodes = data_graph.get_all_nodes(0) | |||||
| >>> neighbors = data_graph.get_sampled_neighbors(nodes, [2, 2], [0, 0]) | |||||
| Raises: | |||||
| TypeError: If `node_list` is not list or ndarray. | |||||
| TypeError: If `neighbor_nums` is not list or ndarray. | |||||
| TypeError: If `neighbor_types` is not list or ndarray. | |||||
| """ | |||||
| return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array() | |||||
| @check_gnn_get_neg_sampled_neighbors | |||||
| def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): | |||||
| """ | |||||
| Get `neg_neighbor_type` negative sampled neighbors of the nodes in `node_list`. | |||||
| Args: | |||||
| node_list (list or numpy.ndarray): The given list of nodes. | |||||
| neg_neighbor_num (int): Number of neighbors sampled. | |||||
| neg_neighbor_type (int): Specify the type of negative neighbor. | |||||
| Returns: | |||||
| numpy.ndarray: array of nodes. | |||||
| Examples: | |||||
| >>> import mindspore.dataset as ds | |||||
| >>> data_graph = ds.GraphData('dataset_file', 2) | |||||
| >>> nodes = data_graph.get_all_nodes(0) | |||||
| >>> neg_neighbors = data_graph.get_neg_sampled_neighbors(nodes, 5, 0) | |||||
| Raises: | |||||
| TypeError: If `node_list` is not list or ndarray. | |||||
| TypeError: If `neg_neighbor_num` is not integer. | |||||
| TypeError: If `neg_neighbor_type` is not integer. | |||||
| """ | |||||
| return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array() | |||||
| @check_gnn_get_node_feature | @check_gnn_get_node_feature | ||||
| def get_node_feature(self, node_list, feature_types): | def get_node_feature(self, node_list, feature_types): | ||||
| """ | """ | ||||
| @@ -111,3 +201,13 @@ class GraphData: | |||||
| if isinstance(node_list, list): | if isinstance(node_list, list): | ||||
| node_list = np.array(node_list, dtype=np.int32) | node_list = np.array(node_list, dtype=np.int32) | ||||
| return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)] | return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)] | ||||
| def graph_info(self): | |||||
| """ | |||||
| Get the meta information of the graph, including the number of nodes, the type of nodes, | |||||
| the feature information of nodes, the number of edges, the type of edges, and the feature information of edges. | |||||
| Returns: | |||||
| Dict: Meta information of the graph. The keys are node_type, edge_type, node_num, edge_num, | |||||
| node_feature_type and edge_feature_type. | |||||
| """ | |||||
| return self._graph.graph_info() | |||||
| @@ -1153,6 +1153,36 @@ def check_gnn_get_all_nodes(method): | |||||
| return new_method | return new_method | ||||
| def check_gnn_get_all_edges(method): | |||||
| """A wrapper that wrap a parameter checker to the GNN `get_all_edges` function.""" | |||||
| @wraps(method) | |||||
| def new_method(*args, **kwargs): | |||||
| param_dict = make_param_dict(method, args, kwargs) | |||||
| # check edge_type; required argument | |||||
| check_type(param_dict.get("edge_type"), 'edge_type', int) | |||||
| return method(*args, **kwargs) | |||||
| return new_method | |||||
| def check_gnn_get_nodes_from_edges(method): | |||||
| """A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function.""" | |||||
| @wraps(method) | |||||
| def new_method(*args, **kwargs): | |||||
| param_dict = make_param_dict(method, args, kwargs) | |||||
| # check edge_list; required argument | |||||
| check_gnn_list_or_ndarray(param_dict.get("edge_list"), 'edge_list') | |||||
| return method(*args, **kwargs) | |||||
| return new_method | |||||
| def check_gnn_get_all_neighbors(method): | def check_gnn_get_all_neighbors(method): | ||||
| """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" | """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" | ||||
| @@ -1171,6 +1201,61 @@ def check_gnn_get_all_neighbors(method): | |||||
| return new_method | return new_method | ||||
| def check_gnn_get_sampled_neighbors(method): | |||||
| """A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function.""" | |||||
| @wraps(method) | |||||
| def new_method(*args, **kwargs): | |||||
| param_dict = make_param_dict(method, args, kwargs) | |||||
| # check node_list; required argument | |||||
| check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') | |||||
| # check neighbor_nums; required argument | |||||
| neighbor_nums = param_dict.get("neighbor_nums") | |||||
| check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums') | |||||
| if len(neighbor_nums) > 6: | |||||
| raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( | |||||
| 'neighbor_nums', len(neighbor_nums))) | |||||
| # check neighbor_types; required argument | |||||
| neighbor_types = param_dict.get("neighbor_types") | |||||
| check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types') | |||||
| if len(neighbor_nums) > 6: | |||||
| raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format( | |||||
| 'neighbor_types', len(neighbor_types))) | |||||
| if len(neighbor_nums) != len(neighbor_types): | |||||
| raise ValueError( | |||||
| "The number of members of neighbor_nums and neighbor_types is inconsistent") | |||||
| return method(*args, **kwargs) | |||||
| return new_method | |||||
| def check_gnn_get_neg_sampled_neighbors(method): | |||||
| """A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" | |||||
| @wraps(method) | |||||
| def new_method(*args, **kwargs): | |||||
| param_dict = make_param_dict(method, args, kwargs) | |||||
| # check node_list; required argument | |||||
| check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') | |||||
| # check neg_neighbor_num; required argument | |||||
| check_type(param_dict.get("neg_neighbor_num"), 'neg_neighbor_num', int) | |||||
| # check neg_neighbor_type; required argument | |||||
| check_type(param_dict.get("neg_neighbor_type"), | |||||
| 'neg_neighbor_type', int) | |||||
| return method(*args, **kwargs) | |||||
| return new_method | |||||
| def check_aligned_list(param, param_name, membor_type): | def check_aligned_list(param, param_name, membor_type): | ||||
| """Check whether the structure of each member of the list is the same.""" | """Check whether the structure of each member of the list is the same.""" | ||||
| @@ -13,8 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <algorithm> | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <unordered_set> | |||||
| #include "common/common.h" | #include "common/common.h" | ||||
| #include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
| @@ -45,7 +47,7 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) { | |||||
| &default_feature_map) | &default_feature_map) | ||||
| .IsOk()); | .IsOk()); | ||||
| EXPECT_EQ(n_id_map.size(), 20); | EXPECT_EQ(n_id_map.size(), 20); | ||||
| EXPECT_EQ(e_id_map.size(), 20); | |||||
| EXPECT_EQ(e_id_map.size(), 40); | |||||
| EXPECT_EQ(n_type_map[2].size(), 10); | EXPECT_EQ(n_type_map[2].size(), 10); | ||||
| EXPECT_EQ(n_type_map[1].size(), 10); | EXPECT_EQ(n_type_map[1].size(), 10); | ||||
| } | } | ||||
| @@ -56,14 +58,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||||
| Status s = graph.Init(); | Status s = graph.Init(); | ||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| std::vector<NodeMetaInfo> node_info; | |||||
| std::vector<EdgeMetaInfo> edge_info; | |||||
| s = graph.GetMetaInfo(&node_info, &edge_info); | |||||
| MetaInfo meta_info; | |||||
| s = graph.GetMetaInfo(&meta_info); | |||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| EXPECT_TRUE(node_info.size() == 2); | |||||
| EXPECT_TRUE(meta_info.node_type.size() == 2); | |||||
| std::shared_ptr<Tensor> nodes; | std::shared_ptr<Tensor> nodes; | ||||
| s = graph.GetNodes(node_info[1].type, -1, &nodes); | |||||
| s = graph.GetAllNodes(meta_info.node_type[0], &nodes); | |||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| std::vector<NodeIdType> node_list; | std::vector<NodeIdType> node_list; | ||||
| for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | ||||
| @@ -73,13 +74,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||||
| } | } | ||||
| } | } | ||||
| std::shared_ptr<Tensor> neighbors; | std::shared_ptr<Tensor> neighbors; | ||||
| s = graph.GetAllNeighbors(node_list, node_info[0].type, &neighbors); | |||||
| s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], &neighbors); | |||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>"); | EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>"); | ||||
| TensorRow features; | TensorRow features; | ||||
| s = graph.GetNodeFeature(nodes, node_info[1].feature_type, &features); | |||||
| s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features); | |||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| EXPECT_TRUE(features.size() == 3); | |||||
| EXPECT_TRUE(features.size() == 4); | |||||
| EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>"); | EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>"); | ||||
| EXPECT_TRUE(features[0]->ToString() == | EXPECT_TRUE(features[0]->ToString() == | ||||
| "Tensor (shape: <10,5>, Type: int32)\n" | "Tensor (shape: <10,5>, Type: int32)\n" | ||||
| @@ -91,3 +92,106 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||||
| EXPECT_TRUE(features[2]->shape().ToString() == "<10>"); | EXPECT_TRUE(features[2]->shape().ToString() == "<10>"); | ||||
| EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]"); | EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]"); | ||||
| } | } | ||||
| TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | |||||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||||
| Graph graph(path, 1); | |||||
| Status s = graph.Init(); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| MetaInfo meta_info; | |||||
| s = graph.GetMetaInfo(&meta_info); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| EXPECT_TRUE(meta_info.node_type.size() == 2); | |||||
| std::shared_ptr<Tensor> edges; | |||||
| s = graph.GetAllEdges(meta_info.edge_type[0], &edges); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| std::vector<EdgeIdType> edge_list; | |||||
| edge_list.resize(edges->Size()); | |||||
| std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(), | |||||
| [](const EdgeIdType edge) { return edge; }); | |||||
| std::shared_ptr<Tensor> nodes; | |||||
| s = graph.GetNodesFromEdges(edge_list, &nodes); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| std::unordered_set<NodeIdType> node_set; | |||||
| std::vector<NodeIdType> node_list; | |||||
| int index = 0; | |||||
| for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | |||||
| index++; | |||||
| if (index % 2 == 0) { | |||||
| continue; | |||||
| } | |||||
| node_set.emplace(*itr); | |||||
| if (node_set.size() >= 5) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| node_list.resize(node_set.size()); | |||||
| std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; }); | |||||
| std::shared_ptr<Tensor> neighbors; | |||||
| s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); | |||||
| neighbors.reset(); | |||||
| s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>"); | |||||
| neighbors.reset(); | |||||
| s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, | |||||
| {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>"); | |||||
| neighbors.reset(); | |||||
| s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors); | |||||
| EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); | |||||
| neighbors.reset(); | |||||
| s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); | |||||
| EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != | |||||
| std::string::npos); | |||||
| neighbors.reset(); | |||||
| s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors); | |||||
| EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos); | |||||
| } | |||||
| TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { | |||||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||||
| Graph graph(path, 1); | |||||
| Status s = graph.Init(); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| MetaInfo meta_info; | |||||
| s = graph.GetMetaInfo(&meta_info); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| EXPECT_TRUE(meta_info.node_type.size() == 2); | |||||
| std::shared_ptr<Tensor> nodes; | |||||
| s = graph.GetAllNodes(meta_info.node_type[0], &nodes); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| std::vector<NodeIdType> node_list; | |||||
| for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | |||||
| node_list.push_back(*itr); | |||||
| if (node_list.size() >= 10) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| std::shared_ptr<Tensor> neg_neighbors; | |||||
| s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| EXPECT_TRUE(neg_neighbors->shape().ToString() == "<10,4>"); | |||||
| neg_neighbors.reset(); | |||||
| s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors); | |||||
| EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); | |||||
| neg_neighbors.reset(); | |||||
| s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors); | |||||
| EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos); | |||||
| } | |||||
| @@ -12,6 +12,7 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| import random | |||||
| import pytest | import pytest | ||||
| import numpy as np | import numpy as np | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| @@ -77,8 +78,110 @@ def test_graphdata_getnodefeature_input_check(): | |||||
| g.get_node_feature(input_list, [1, "a"]) | g.get_node_feature(input_list, [1, "a"]) | ||||
def test_graphdata_getsampledneighbors():
    """Check the shape returned by get_sampled_neighbors for nodes taken from type-0 edges."""
    graph = ds.GraphData(DATASET_FILE, 1)
    all_edges = graph.get_all_edges(0)
    edge_nodes = graph.get_nodes_from_edges(all_edges)
    assert len(edge_nodes) == 40

    # De-duplicated first column of the first 21 rows serves as the query node list.
    query_nodes = np.unique(edge_nodes[0:21, 0])
    sampled = graph.get_sampled_neighbors(query_nodes, [2, 3], [2, 1])
    assert sampled.shape == (10, 9)
def test_graphdata_getnegsampledneighbors():
    """Check the shape returned by get_neg_sampled_neighbors for all nodes of type 1."""
    graph = ds.GraphData(DATASET_FILE, 2)
    type1_nodes = graph.get_all_nodes(1)
    assert len(type1_nodes) == 10

    # Five negative samples of type 2 are requested per node.
    neg_sampled = graph.get_neg_sampled_neighbors(type1_nodes, 5, 2)
    assert neg_sampled.shape == (10, 6)
def test_graphdata_graphinfo():
    """Check every field of the dictionary returned by GraphData.graph_info()."""
    info = ds.GraphData(DATASET_FILE, 2).graph_info()
    # Expected content of the test graph's metadata, checked key by key in this order.
    expected = {
        'node_type': [1, 2],
        'edge_type': [0],
        'node_num': {1: 10, 2: 10},
        'edge_num': {0: 40},
        'node_feature_type': [1, 2, 3, 4],
        'edge_feature_type': [],
    }
    for key, value in expected.items():
        assert info[key] == value
class RandomBatchedSampler(ds.Sampler):
    """Sampler that yields shuffled, non-repeating indices in fixed-size batches.

    Indices run from 1 to ``index_range`` inclusive; each yielded item is a list of
    ``num_edges_per_sample`` indices. A trailing batch that would be shorter is dropped.
    """

    def __init__(self, index_range, num_edges_per_sample):
        super().__init__()
        self.index_range = index_range
        self.num_edges_per_sample = num_edges_per_sample

    def __iter__(self):
        shuffled = list(range(1, self.index_range + 1))
        # Call random.seed(...) before shuffling if a reproducible order is needed.
        random.shuffle(shuffled)
        batch_size = self.num_edges_per_sample
        for start in range(0, self.index_range, batch_size):
            stop = start + batch_size
            # Drop the remainder: skip a batch that would run past the index range.
            if stop > self.index_range:
                continue
            yield shuffled[start:stop]
class GNNGraphDataset:
    """Random-access source that turns a batch of edge ids into sampled neighbors and features.

    ``g`` is the graph object queried for nodes, neighbors and features; ``batch_num`` is the
    number of edges per sample and is used only to compute the dataset length.
    """

    def __init__(self, g, batch_num):
        self.g = g
        self.batch_num = batch_num

    def __len__(self):
        # Sample count = number of type-0 edges divided (floor) by the edges per sample.
        edge_total = self.g.graph_info()['edge_num'][0]
        return edge_total // self.batch_num

    def __getitem__(self, index):
        # `index` is the batch of indices yielded by the sampler; it is cast to int32 edge ids.
        graph = self.g
        edge_ids = index.astype(np.int32)
        # Keep only the first column of the nodes returned for these edges.
        first_nodes = graph.get_nodes_from_edges(edge_ids)[:, 0]

        negatives = graph.get_neg_sampled_neighbors(
            node_list=first_nodes, neg_neighbor_num=3, neg_neighbor_type=1)
        pos_neighbors = graph.get_sampled_neighbors(
            node_list=first_nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1])

        # Column 0 of the negative-sample result is skipped; the rest is flattened to a node list.
        neg_node_list = negatives[:, 1:].reshape(-1)
        neg_neighbors = graph.get_sampled_neighbors(
            node_list=neg_node_list, neighbor_nums=[2, 2], neighbor_types=[2, 2])

        pos_features = graph.get_node_feature(
            node_list=pos_neighbors, feature_types=[2, 3])
        neg_features = graph.get_node_feature(
            node_list=neg_neighbors, feature_types=[2, 3])

        # Returns the first feature of the positive set and the second feature of the negative set.
        return pos_neighbors, neg_neighbors, pos_features[0], neg_features[1]
def test_graphdata_generatordataset():
    """Run GNNGraphDataset through GeneratorDataset with a batched sampler and check every row."""
    graph = ds.GraphData(DATASET_FILE)
    batch_num = 2
    edge_num = graph.graph_info()['edge_num'][0]
    # Expected shape per output column; key order is also the column order given to the dataset.
    expected_shapes = {
        "neighbors": (2, 7),
        "neg_neighbors": (6, 7),
        "neighbors_features": (2, 7),
        "neg_neighbors_features": (6, 7),
    }
    dataset = ds.GeneratorDataset(source=GNNGraphDataset(graph, batch_num),
                                  column_names=list(expected_shapes),
                                  sampler=RandomBatchedSampler(edge_num, batch_num),
                                  num_parallel_workers=4)
    dataset = dataset.repeat(2)

    row_count = 0
    for row in dataset.create_dict_iterator():
        for column, shape in expected_shapes.items():
            assert row[column].shape == shape
        row_count += 1
    # 40 rows are expected in total across the two repeats.
    assert row_count == 40
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_graphdata_getfullneighbor() | test_graphdata_getfullneighbor() | ||||
| logger.info('test_graphdata_getfullneighbor Ended.\n') | logger.info('test_graphdata_getfullneighbor Ended.\n') | ||||
| test_graphdata_getnodefeature_input_check() | test_graphdata_getnodefeature_input_check() | ||||
| logger.info('test_graphdata_getnodefeature_input_check Ended.\n') | logger.info('test_graphdata_getnodefeature_input_check Ended.\n') | ||||
| test_graphdata_getsampledneighbors() | |||||
| logger.info('test_graphdata_getsampledneighbors Ended.\n') | |||||
| test_graphdata_getnegsampledneighbors() | |||||
| logger.info('test_graphdata_getnegsampledneighbors Ended.\n') | |||||
| test_graphdata_graphinfo() | |||||
| logger.info('test_graphdata_graphinfo Ended.\n') | |||||
| test_graphdata_generatordataset() | |||||
| logger.info('test_graphdata_generatordataset Ended.\n') | |||||