You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph.h 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_GNN_GRAPH_H_
  17. #define DATASET_ENGINE_GNN_GRAPH_H_
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "dataset/core/tensor.h"
#include "dataset/core/tensor_row.h"
#include "dataset/engine/gnn/graph_loader.h"
#include "dataset/engine/gnn/feature.h"
#include "dataset/engine/gnn/node.h"
#include "dataset/engine/gnn/edge.h"
#include "dataset/util/status.h"
  33. namespace mindspore {
  34. namespace dataset {
  35. namespace gnn {
  36. const float kGnnEpsilon = 0.0001;
  37. const uint32_t kMaxNumWalks = 80;
  38. using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
  39. struct MetaInfo {
  40. std::vector<NodeType> node_type;
  41. std::vector<EdgeType> edge_type;
  42. std::map<NodeType, NodeIdType> node_num;
  43. std::map<EdgeType, EdgeIdType> edge_num;
  44. std::vector<FeatureType> node_feature_type;
  45. std::vector<FeatureType> edge_feature_type;
  46. };
  47. class Graph {
  48. public:
  49. // Constructor
  50. // @param std::string dataset_file -
  51. // @param int32_t num_workers - number of parallel threads
  52. Graph(std::string dataset_file, int32_t num_workers);
  53. ~Graph() = default;
  54. // Get all nodes from the graph.
  55. // @param NodeType node_type - type of node
  56. // @param std::shared_ptr<Tensor> *out - Returned nodes id
  57. // @return Status - The error code return
  58. Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out);
  59. // Get all edges from the graph.
  60. // @param NodeType edge_type - type of edge
  61. // @param std::shared_ptr<Tensor> *out - Returned edge ids
  62. // @return Status - The error code return
  63. Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out);
  64. // Get the node id from the edge.
  65. // @param std::vector<EdgeIdType> edge_list - List of edges
  66. // @param std::shared_ptr<Tensor> *out - Returned node ids
  67. // @return Status - The error code return
  68. Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out);
  69. // All neighbors of the acquisition node.
  70. // @param std::vector<NodeType> node_list - List of nodes
  71. // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
  72. // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
  73. // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
  74. // is not enough, fill in tensor as -1.
  75. // @return Status - The error code return
  76. Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
  77. std::shared_ptr<Tensor> *out);
  78. // Get sampled neighbors.
  79. // @param std::vector<NodeType> node_list - List of nodes
  80. // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
  81. // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
  82. // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
  83. // @return Status - The error code return
  84. Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
  85. const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
  86. // Get negative sampled neighbors.
  87. // @param std::vector<NodeType> node_list - List of nodes
  88. // @param NodeIdType samples_num - Number of neighbors sampled
  89. // @param NodeType neg_neighbor_type - The type of negative neighbor.
  90. // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
  91. // @return Status - The error code return
  92. Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
  93. NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
  94. // Node2vec random walk.
  95. // @param std::vector<NodeIdType> node_list - List of nodes
  96. // @param std::vector<NodeType> meta_path - node type of each step
  97. // @param float step_home_param - return hyper parameter in node2vec algorithm
  98. // @param float step_away_param - inout hyper parameter in node2vec algorithm
  99. // @param NodeIdType default_node - default node id
  100. // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
  101. // @return Status - The error code return
  102. Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
  103. float step_home_param, float step_away_param, NodeIdType default_node,
  104. std::shared_ptr<Tensor> *out);
  105. // Get the feature of a node
  106. // @param std::shared_ptr<Tensor> nodes - List of nodes
  107. // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
  108. // does not exist.
  109. // @param TensorRow *out - Returned features
  110. // @return Status - The error code return
  111. Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
  112. TensorRow *out);
  113. // Get the feature of a edge
  114. // @param std::shared_ptr<Tensor> edget - List of edges
  115. // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
  116. // does not exist.
  117. // @param Tensor *out - Returned features
  118. // @return Status - The error code return
  119. Status GetEdgeFeature(const std::shared_ptr<Tensor> &edget, const std::vector<FeatureType> &feature_types,
  120. TensorRow *out);
  121. // Get meta information of graph
  122. // @param MetaInfo *meta_info - Returned meta information
  123. // @return Status - The error code return
  124. Status GetMetaInfo(MetaInfo *meta_info);
  125. // Return meta information to python layer
  126. Status GraphInfo(py::dict *out);
  127. Status Init();
  128. private:
  129. class RandomWalkBase {
  130. public:
  131. explicit RandomWalkBase(Graph *graph);
  132. Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
  133. float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1,
  134. int32_t num_walks = 1, int32_t num_workers = 1);
  135. ~RandomWalkBase() = default;
  136. Status SimulateWalk(std::vector<std::vector<NodeIdType>> *walks);
  137. private:
  138. Status Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path);
  139. Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
  140. std::shared_ptr<StochasticIndex> *node_probability);
  141. Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index,
  142. std::shared_ptr<StochasticIndex> *edge_probability);
  143. static StochasticIndex GenerateProbability(const std::vector<float> &probability);
  144. static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index);
  145. template <typename T>
  146. std::vector<float> Normalize(const std::vector<T> &non_normalized_probability);
  147. Graph *graph_;
  148. std::vector<NodeIdType> node_list_;
  149. std::vector<NodeType> meta_path_;
  150. float step_home_param_; // Return hyper parameter. Default is 1.0
  151. float step_away_param_; // Inout hyper parameter. Default is 1.0
  152. NodeIdType default_node_;
  153. int32_t num_walks_; // Number of walks per source. Default is 10
  154. int32_t num_workers_; // The number of worker threads. Default is 1
  155. };
  156. // Load graph data from mindrecord file
  157. // @return Status - The error code return
  158. Status LoadNodeAndEdge();
  159. // Create Tensor By Vector
  160. // @param std::vector<std::vector<T>> &data -
  161. // @param DataType type -
  162. // @param std::shared_ptr<Tensor> *out -
  163. // @return Status - The error code return
  164. template <typename T>
  165. Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out);
  166. // Complete vector
  167. // @param std::vector<std::vector<T>> *data - To be completed vector
  168. // @param size_t max_size - The size of the completed vector
  169. // @param T default_value - Filled default
  170. // @return Status - The error code return
  171. template <typename T>
  172. Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value);
  173. // Get the default feature of a node
  174. // @param FeatureType feature_type -
  175. // @param std::shared_ptr<Feature> *out_feature - Returned feature
  176. // @return Status - The error code return
  177. Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
  178. // Find node object using node id
  179. // @param NodeIdType id -
  180. // @param std::shared_ptr<Node> *node - Returned node object
  181. // @return Status - The error code return
  182. Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node);
  183. // Negative sampling
  184. // @param std::vector<NodeIdType> &input_data - The data set to be sampled
  185. // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded
  186. // @param int32_t samples_num -
  187. // @param std::vector<NodeIdType> *out_samples - Sampling results returned
  188. // @return Status - The error code return
  189. Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data,
  190. int32_t samples_num, std::vector<NodeIdType> *out_samples);
  191. Status CheckSamplesNum(NodeIdType samples_num);
  192. std::string dataset_file_;
  193. int32_t num_workers_; // The number of worker threads
  194. std::mt19937 rnd_;
  195. RandomWalkBase random_walk_;
  196. std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
  197. std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
  198. std::unordered_map<EdgeType, std::vector<EdgeIdType>> edge_type_map_;
  199. std::unordered_map<EdgeIdType, std::shared_ptr<Edge>> edge_id_map_;
  200. std::unordered_map<NodeType, std::unordered_set<FeatureType>> node_feature_map_;
  201. std::unordered_map<EdgeType, std::unordered_set<FeatureType>> edge_feature_map_;
  202. std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_feature_map_;
  203. };
  204. } // namespace gnn
  205. } // namespace dataset
  206. } // namespace mindspore
  207. #endif // DATASET_ENGINE_GNN_GRAPH_H_