/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef DATASET_ENGINE_GNN_GRAPH_H_ #define DATASET_ENGINE_GNN_GRAPH_H_ #include #include #include #include #include #include #include #include #include "dataset/core/tensor.h" #include "dataset/core/tensor_row.h" #include "dataset/engine/gnn/graph_loader.h" #include "dataset/engine/gnn/feature.h" #include "dataset/engine/gnn/node.h" #include "dataset/engine/gnn/edge.h" #include "dataset/util/status.h" namespace mindspore { namespace dataset { namespace gnn { const float kGnnEpsilon = 0.0001; const uint32_t kMaxNumWalks = 80; using StochasticIndex = std::pair, std::vector>; struct MetaInfo { std::vector node_type; std::vector edge_type; std::map node_num; std::map edge_num; std::vector node_feature_type; std::vector edge_feature_type; }; class Graph { public: // Constructor // @param std::string dataset_file - // @param int32_t num_workers - number of parallel threads Graph(std::string dataset_file, int32_t num_workers); ~Graph() = default; // Get all nodes from the graph. // @param NodeType node_type - type of node // @param std::shared_ptr *out - Returned nodes id // @return Status - The error code return Status GetAllNodes(NodeType node_type, std::shared_ptr *out); // Get all edges from the graph. // @param NodeType edge_type - type of edge // @param std::shared_ptr *out - Returned edge ids // @return Status - The error code return Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out); // Get the node id from the edge. // @param std::vector edge_list - List of edges // @param std::shared_ptr *out - Returned node ids // @return Status - The error code return Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out); // All neighbors of the acquisition node. // @param std::vector node_list - List of nodes // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors // is not enough, fill in tensor as -1. // @return Status - The error code return Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, std::shared_ptr *out); // Get sampled neighbors. // @param std::vector node_list - List of nodes // @param std::vector neighbor_nums - Number of neighbors sampled per hop // @param std::vector neighbor_types - Neighbor type sampled per hop // @param std::shared_ptr *out - Returned neighbor's id. // @return Status - The error code return Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, const std::vector &neighbor_types, std::shared_ptr *out); // Get negative sampled neighbors. // @param std::vector node_list - List of nodes // @param NodeIdType samples_num - Number of neighbors sampled // @param NodeType neg_neighbor_type - The type of negative neighbor. // @param std::shared_ptr *out - Returned negative neighbor's id. // @return Status - The error code return Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, NodeType neg_neighbor_type, std::shared_ptr *out); // Node2vec random walk. // @param std::vector node_list - List of nodes // @param std::vector meta_path - node type of each step // @param float step_home_param - return hyper parameter in node2vec algorithm // @param float step_away_param - inout hyper parameter in node2vec algorithm // @param NodeIdType default_node - default node id // @param std::shared_ptr *out - Returned nodes id in walk path // @return Status - The error code return Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, float step_home_param, float step_away_param, NodeIdType default_node, std::shared_ptr *out); // Get the feature of a node // @param std::shared_ptr nodes - List of nodes // @param std::vector feature_types - Types of features, An error will be reported if the feature type // does not exist. // @param TensorRow *out - Returned features // @return Status - The error code return Status GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, TensorRow *out); // Get the feature of a edge // @param std::shared_ptr edget - List of edges // @param std::vector feature_types - Types of features, An error will be reported if the feature type // does not exist. // @param Tensor *out - Returned features // @return Status - The error code return Status GetEdgeFeature(const std::shared_ptr &edget, const std::vector &feature_types, TensorRow *out); // Get meta information of graph // @param MetaInfo *meta_info - Returned meta information // @return Status - The error code return Status GetMetaInfo(MetaInfo *meta_info); // Return meta information to python layer Status GraphInfo(py::dict *out); Status Init(); private: class RandomWalkBase { public: explicit RandomWalkBase(Graph *graph); Status Build(const std::vector &node_list, const std::vector &meta_path, float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, int32_t num_walks = 1, int32_t num_workers = 1); ~RandomWalkBase() = default; Status SimulateWalk(std::vector> *walks); private: Status Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path); Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, std::shared_ptr *node_probability); Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, std::shared_ptr *edge_probability); static StochasticIndex GenerateProbability(const std::vector &probability); static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index); template std::vector Normalize(const std::vector &non_normalized_probability); Graph *graph_; std::vector node_list_; std::vector meta_path_; float step_home_param_; // Return hyper parameter. Default is 1.0 float step_away_param_; // Inout hyper parameter. Default is 1.0 NodeIdType default_node_; int32_t num_walks_; // Number of walks per source. Default is 10 int32_t num_workers_; // The number of worker threads. Default is 1 }; // Load graph data from mindrecord file // @return Status - The error code return Status LoadNodeAndEdge(); // Create Tensor By Vector // @param std::vector> &data - // @param DataType type - // @param std::shared_ptr *out - // @return Status - The error code return template Status CreateTensorByVector(const std::vector> &data, DataType type, std::shared_ptr *out); // Complete vector // @param std::vector> *data - To be completed vector // @param size_t max_size - The size of the completed vector // @param T default_value - Filled default // @return Status - The error code return template Status ComplementVector(std::vector> *data, size_t max_size, T default_value); // Get the default feature of a node // @param FeatureType feature_type - // @param std::shared_ptr *out_feature - Returned feature // @return Status - The error code return Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); // Find node object using node id // @param NodeIdType id - // @param std::shared_ptr *node - Returned node object // @return Status - The error code return Status GetNodeByNodeId(NodeIdType id, std::shared_ptr *node); // Negative sampling // @param std::vector &input_data - The data set to be sampled // @param std::unordered_set &exclude_data - Data to be excluded // @param int32_t samples_num - // @param std::vector *out_samples - Sampling results returned // @return Status - The error code return Status NegativeSample(const std::vector &input_data, const std::unordered_set &exclude_data, int32_t samples_num, std::vector *out_samples); Status CheckSamplesNum(NodeIdType samples_num); std::string dataset_file_; int32_t num_workers_; // The number of worker threads std::mt19937 rnd_; RandomWalkBase random_walk_; std::unordered_map> node_type_map_; std::unordered_map> node_id_map_; std::unordered_map> edge_type_map_; std::unordered_map> edge_id_map_; std::unordered_map> node_feature_map_; std::unordered_map> edge_feature_map_; std::unordered_map> default_feature_map_; }; } // namespace gnn } // namespace dataset } // namespace mindspore #endif // DATASET_ENGINE_GNN_GRAPH_H_