| @@ -0,0 +1,81 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| User-defined API for MindRecord GNN writer. | |||
| """ | |||
# Undirected edges of the sample social graph. Every entry ends up as a
# two-element list [node_a, node_b]; consecutive pairs that share the same
# first node are kept together on one row.
social_data = [list(pair) for pair in (
    (348, 350), (348, 327), (348, 329), (348, 331), (348, 335), (348, 336), (348, 337), (348, 338),
    (348, 340), (348, 341), (348, 342), (348, 343), (348, 344), (348, 345), (348, 346), (348, 347),
    (347, 351), (347, 327), (347, 329), (347, 331), (347, 335), (347, 341), (347, 345), (347, 346),
    (346, 335), (346, 340), (346, 339), (346, 349), (346, 353), (346, 354), (346, 341), (346, 345),
    (345, 335), (345, 336), (345, 341),
    (344, 338), (344, 342),
    (343, 332), (343, 338), (343, 342),
    (342, 332),
    (340, 349), (334, 349), (333, 349), (330, 349), (328, 349), (359, 349),
    (358, 352), (358, 349), (358, 354), (358, 356),
    (357, 350), (357, 354), (357, 356),
    (356, 350),
    (355, 352), (353, 350), (352, 349), (351, 349), (350, 349),
)]

# profile: (num_features, feature_data_types, feature_shapes)
# Both profiles declare zero features, so the type and shape lists are empty.
node_profile = (0, [], [])
edge_profile = (0, [], [])
def yield_nodes(task_id=0):
    """
    Generate node data

    Every node id that appears as an endpoint in `social_data` is emitted
    exactly once, in ascending id order.

    Args:
        task_id (int): Identifier of the calling task; only used in the log line.

    Yields:
        data (dict): data row which is dict, with keys 'id' and 'type'.
    """
    print("Node task is {}".format(task_id))
    # A set removes duplicates in one pass; the previous list-membership test
    # rescanned the list for every endpoint.
    node_list = sorted({node_id for src, dst in social_data for node_id in (src, dst)})
    print(node_list)
    for node_id in node_list:
        # All nodes in this sample share node type 1.
        yield {'id': node_id, 'type': 1}
def yield_edges(task_id=0):
    """
    Generate edge data

    Each undirected pair in `social_data` produces two directed rows: first
    the pair as listed, then the reversed pair. Edge ids are consecutive and
    start at 1.

    Args:
        task_id (int): Identifier of the calling task; only used in the log line.

    Yields:
        data (dict): data row which is dict, with keys 'id', 'src_id',
            'dst_id' and 'type'.
    """
    print("Edge task is {}".format(task_id))
    next_edge_id = 1
    for pair in social_data:
        first, second = pair[0], pair[1]
        for src, dst in ((first, second), (second, first)):
            yield {'id': next_edge_id, 'src_id': src, 'dst_id': dst, 'type': 1}
            next_edge_id += 1
| @@ -0,0 +1,10 @@ | |||
#!/bin/bash
# Regenerate the "sns" MindRecord files under MINDRECORD_PATH with writer.py.
MINDRECORD_PATH=/tmp/sns

# Create the output directory if it is missing, then remove files left over
# from a previous run. "${VAR:?}" makes the shell abort instead of expanding
# to "/*" should the variable ever be empty or unset.
mkdir -p "$MINDRECORD_PATH"
rm -f "${MINDRECORD_PATH:?}"/*

python writer.py --mindrecord_script sns \
    --mindrecord_file "$MINDRECORD_PATH/sns" \
    --mindrecord_partitions 1 \
    --mindrecord_header_size_by_bit 14 \
    --mindrecord_page_size_by_bit 15
| @@ -584,9 +584,16 @@ void bindGraphData(py::module *m) { | |||
| THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); | |||
| return out; | |||
| }) | |||
| .def("graph_info", [](gnn::Graph &g) { | |||
| py::dict out; | |||
| THROW_IF_ERROR(g.GraphInfo(&out)); | |||
| .def("graph_info", | |||
| [](gnn::Graph &g) { | |||
| py::dict out; | |||
| THROW_IF_ERROR(g.GraphInfo(&out)); | |||
| return out; | |||
| }) | |||
| .def("random_walk", [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path, | |||
| float step_home_param, float step_away_param, gnn::NodeIdType default_node) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); | |||
| return out; | |||
| }); | |||
| } | |||
| @@ -29,7 +29,7 @@ namespace dataset { | |||
| namespace gnn { | |||
| Graph::Graph(std::string dataset_file, int32_t num_workers) | |||
| : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) { | |||
| : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { | |||
| rnd_.seed(GetSeed()); | |||
| MS_LOG(INFO) << "num_workers:" << num_workers; | |||
| } | |||
| @@ -240,8 +240,13 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, | |||
| float q, NodeIdType default_node, std::shared_ptr<Tensor> *out) { | |||
| Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out) { | |||
| RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); | |||
| std::vector<std::vector<NodeIdType>> walks; | |||
| RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); | |||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({walks}, DataType(DataType::DE_INT32), out)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -386,6 +391,195 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||
| return Status::OK(); | |||
| } | |||
| Graph::RandomWalkBase::RandomWalkBase(Graph *graph) | |||
| : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} | |||
| Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, const NodeIdType default_node, | |||
| int32_t num_walks, int32_t num_workers) { | |||
| node_list_ = node_list; | |||
| if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { | |||
| std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) + | |||
| ". The size of input path is " + std::to_string(meta_path.size()); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| meta_path_ = meta_path; | |||
| if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { | |||
| std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + | |||
| std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) + | |||
| ", step_away_param: " + std::to_string(step_away_param); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| step_home_param_ = step_home_param; | |||
| step_away_param_ = step_away_param; | |||
| default_node_ = default_node; | |||
| num_walks_ = num_walks; | |||
| num_workers_ = num_workers; | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) { | |||
| // Simulate a random walk starting from start node. | |||
| auto walk = std::vector<NodeIdType>(1, start_node); // walk is an vector | |||
| // walk simulate | |||
| while (walk.size() - 1 < meta_path_.size()) { | |||
| // current nodE | |||
| auto cur_node_id = walk.back(); | |||
| std::shared_ptr<Node> cur_node; | |||
| RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node)); | |||
| // current neighbors | |||
| std::vector<NodeIdType> cur_neighbors; | |||
| RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true)); | |||
| std::sort(cur_neighbors.begin(), cur_neighbors.end()); | |||
| // break if no neighbors | |||
| if (cur_neighbors.empty()) { | |||
| break; | |||
| } | |||
| // walk by the fist node, then by the previous 2 nodes | |||
| std::shared_ptr<StochasticIndex> stochastic_index; | |||
| if (walk.size() == 1) { | |||
| RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index)); | |||
| } else { | |||
| NodeIdType prev_node_id = walk[walk.size() - 2]; | |||
| RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index)); | |||
| } | |||
| NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)]; | |||
| walk.push_back(next_node_id); | |||
| } | |||
| while (walk.size() - 1 < meta_path_.size()) { | |||
| walk.push_back(default_node_); | |||
| } | |||
| *walk_path = std::move(walk); | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) { | |||
| // Repeatedly simulate random walks from each node | |||
| std::vector<uint32_t> permutation(node_list_.size()); | |||
| std::iota(permutation.begin(), permutation.end(), 0); | |||
| for (int32_t i = 0; i < num_walks_; i++) { | |||
| unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); | |||
| std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed)); | |||
| for (const auto &i_perm : permutation) { | |||
| std::vector<NodeIdType> walk; | |||
| RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk)); | |||
| walks->push_back(walk); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, | |||
| std::shared_ptr<StochasticIndex> *node_probability) { | |||
| // Generate alias nodes | |||
| std::shared_ptr<Node> node; | |||
| graph_->GetNodeByNodeId(node_id, &node); | |||
| std::vector<NodeIdType> neighbors; | |||
| RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true)); | |||
| std::sort(neighbors.begin(), neighbors.end()); | |||
| auto non_normalized_probability = std::vector<float>(neighbors.size(), 1.0); | |||
| *node_probability = | |||
| std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability))); | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, | |||
| std::shared_ptr<StochasticIndex> *edge_probability) { | |||
| // Get the alias edge setup lists for a given edge. | |||
| std::shared_ptr<Node> src_node; | |||
| graph_->GetNodeByNodeId(src, &src_node); | |||
| std::vector<NodeIdType> src_neighbors; | |||
| RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true)); | |||
| std::shared_ptr<Node> dst_node; | |||
| graph_->GetNodeByNodeId(dst, &dst_node); | |||
| std::vector<NodeIdType> dst_neighbors; | |||
| RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true)); | |||
| std::sort(dst_neighbors.begin(), dst_neighbors.end()); | |||
| std::vector<float> non_normalized_probability; | |||
| for (const auto &dst_nbr : dst_neighbors) { | |||
| if (dst_nbr == src) { | |||
| non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] | |||
| continue; | |||
| } | |||
| auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr); | |||
| if (it != src_neighbors.end()) { | |||
| // stay close, this node connect both src and dst | |||
| non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight'] | |||
| } else { | |||
| // step far away | |||
| non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] | |||
| } | |||
| } | |||
| *edge_probability = | |||
| std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability))); | |||
| return Status::OK(); | |||
| } | |||
| StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) { | |||
| uint32_t K = probability.size(); | |||
| std::vector<int32_t> switch_to_large_index(K, 0); | |||
| std::vector<float> weight(K, .0); | |||
| std::vector<int32_t> smaller; | |||
| std::vector<int32_t> larger; | |||
| auto random_device = GetRandomDevice(); | |||
| std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon); | |||
| float accumulate_threshold = 0.0; | |||
| for (uint32_t i = 0; i < K; i++) { | |||
| float threshold_one = distribution(random_device); | |||
| accumulate_threshold += threshold_one; | |||
| weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold; | |||
| weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i); | |||
| } | |||
| while ((!smaller.empty()) && (!larger.empty())) { | |||
| uint32_t small = smaller.back(); | |||
| smaller.pop_back(); | |||
| uint32_t large = larger.back(); | |||
| larger.pop_back(); | |||
| switch_to_large_index[small] = large; | |||
| weight[large] = weight[large] + weight[small] - 1.0; | |||
| weight[large] < 1.0 ? smaller.push_back(large) : larger.push_back(large); | |||
| } | |||
| return StochasticIndex(switch_to_large_index, weight); | |||
| } | |||
| uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { | |||
| auto switch_to_large_index = stochastic_index.first; | |||
| auto weight = stochastic_index.second; | |||
| const uint32_t size_of_index = switch_to_large_index.size(); | |||
| auto random_device = GetRandomDevice(); | |||
| std::uniform_real_distribution<> distribution(0.0, 1.0); | |||
| // Generate random integer between [0, K) | |||
| uint32_t random_idx = std::floor(distribution(random_device) * size_of_index); | |||
| if (distribution(random_device) < weight[random_idx]) { | |||
| return random_idx; | |||
| } | |||
| return switch_to_large_index[random_idx]; | |||
| } | |||
| template <typename T> | |||
| std::vector<float> Graph::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) { | |||
| float sum_probability = | |||
| 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); | |||
| if (sum_probability < kGnnEpsilon) { | |||
| sum_probability = 1.0; | |||
| } | |||
| std::vector<float> normalized_probability; | |||
| std::transform(non_normalized_probability.begin(), non_normalized_probability.end(), | |||
| std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; }); | |||
| return normalized_probability; | |||
| } | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -16,12 +16,14 @@ | |||
| #ifndef DATASET_ENGINE_GNN_GRAPH_H_ | |||
| #define DATASET_ENGINE_GNN_GRAPH_H_ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "dataset/core/tensor.h" | |||
| #include "dataset/engine/gnn/graph_loader.h" | |||
| @@ -34,6 +36,10 @@ namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
// Tolerance used by the random walk code: lower bound for the step parameters and for a
// normalization sum, and half-width of the random offset applied to sampling weights.
constexpr float kGnnEpsilon = 0.0001;

// Upper bound on the number of entries accepted in a random-walk meta path.
constexpr uint32_t kMaxNumWalks = 80;

// Sampling table: first holds the index to switch to, second the weight of keeping the drawn index.
using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
| struct MetaInfo { | |||
| std::vector<NodeType> node_type; | |||
| std::vector<EdgeType> edge_type; | |||
| @@ -98,8 +104,17 @@ class Graph { | |||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out); | |||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q, | |||
| NodeIdType default_node, std::shared_ptr<Tensor> *out); | |||
| // Node2vec random walk. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param std::vector<NodeType> meta_path - node type of each step | |||
| // @param float step_home_param - return hyper parameter in node2vec algorithm | |||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | |||
| // @param NodeIdType default_node - default node id | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | |||
| // @return Status - The error code return | |||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out); | |||
| // Get the feature of a node | |||
| // @param std::shared_ptr<Tensor> nodes - List of nodes | |||
| @@ -130,6 +145,45 @@ class Graph { | |||
| Status Init(); | |||
| private: | |||
| class RandomWalkBase { | |||
| public: | |||
| explicit RandomWalkBase(Graph *graph); | |||
| Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, | |||
| int32_t num_walks = 1, int32_t num_workers = 1); | |||
| ~RandomWalkBase() = default; | |||
| Status SimulateWalk(std::vector<std::vector<NodeIdType>> *walks); | |||
| private: | |||
| Status Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path); | |||
| Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, | |||
| std::shared_ptr<StochasticIndex> *node_probability); | |||
| Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, | |||
| std::shared_ptr<StochasticIndex> *edge_probability); | |||
| static StochasticIndex GenerateProbability(const std::vector<float> &probability); | |||
| static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index); | |||
| template <typename T> | |||
| std::vector<float> Normalize(const std::vector<T> &non_normalized_probability); | |||
| Graph *graph_; | |||
| std::vector<NodeIdType> node_list_; | |||
| std::vector<NodeType> meta_path_; | |||
| float step_home_param_; // Return hyper parameter. Default is 1.0 | |||
| float step_away_param_; // Inout hyper parameter. Default is 1.0 | |||
| NodeIdType default_node_; | |||
| int32_t num_walks_; // Number of walks per source. Default is 10 | |||
| int32_t num_workers_; // The number of worker threads. Default is 1 | |||
| }; | |||
| // Load graph data from mindrecord file | |||
| // @return Status - The error code return | |||
| Status LoadNodeAndEdge(); | |||
| @@ -174,6 +228,7 @@ class Graph { | |||
| std::string dataset_file_; | |||
| int32_t num_workers_; // The number of worker threads | |||
| std::mt19937 rnd_; | |||
| RandomWalkBase random_walk_; | |||
| std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_; | |||
| std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_; | |||
| @@ -39,17 +39,25 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> | |||
| } | |||
| } | |||
| Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) { | |||
| Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, bool exclude_itself) { | |||
| std::vector<NodeIdType> neighbors; | |||
| auto itr = neighbor_nodes_.find(neighbor_type); | |||
| if (itr != neighbor_nodes_.end()) { | |||
| neighbors.resize(itr->second.size() + 1); | |||
| neighbors[0] = id_; | |||
| std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, | |||
| [](const std::shared_ptr<Node> node) { return node->id(); }); | |||
| if (exclude_itself) { | |||
| neighbors.resize(itr->second.size()); | |||
| std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), | |||
| [](const std::shared_ptr<Node> node) { return node->id(); }); | |||
| } else { | |||
| neighbors.resize(itr->second.size() + 1); | |||
| neighbors[0] = id_; | |||
| std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, | |||
| [](const std::shared_ptr<Node> node) { return node->id(); }); | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; | |||
| neighbors.emplace_back(id_); | |||
| if (!exclude_itself) { | |||
| neighbors.emplace_back(id_); | |||
| } | |||
| } | |||
| *out_neighbors = std::move(neighbors); | |||
| return Status::OK(); | |||
| @@ -47,7 +47,8 @@ class LocalNode : public Node { | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) override; | |||
| Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, | |||
| bool exclude_itself = false) override; | |||
| // Get the sampled neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| @@ -56,7 +56,8 @@ class Node { | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status - The error code return | |||
| virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) = 0; | |||
| virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, | |||
| bool exclude_itself = false) = 0; | |||
| // Get the sampled neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| @@ -22,7 +22,7 @@ from mindspore._c_dataengine import Tensor | |||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | |||
| check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ | |||
| check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature | |||
| check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk | |||
| class GraphData: | |||
| @@ -148,7 +148,8 @@ class GraphData: | |||
| TypeError: If `neighbor_nums` is not list or ndarray. | |||
| TypeError: If `neighbor_types` is not list or ndarray. | |||
| """ | |||
| return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array() | |||
| return self._graph.get_sampled_neighbors( | |||
| node_list, neighbor_nums, neighbor_types).as_array() | |||
| @check_gnn_get_neg_sampled_neighbors | |||
| def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): | |||
| @@ -174,7 +175,8 @@ class GraphData: | |||
| TypeError: If `neg_neighbor_num` is not integer. | |||
| TypeError: If `neg_neighbor_type` is not integer. | |||
| """ | |||
| return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array() | |||
| return self._graph.get_neg_sampled_neighbors( | |||
| node_list, neg_neighbor_num, neg_neighbor_type).as_array() | |||
| @check_gnn_get_node_feature | |||
| def get_node_feature(self, node_list, feature_types): | |||
| @@ -200,7 +202,10 @@ class GraphData: | |||
| """ | |||
| if isinstance(node_list, list): | |||
| node_list = np.array(node_list, dtype=np.int32) | |||
| return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)] | |||
| return [ | |||
| t.as_array() for t in self._graph.get_node_feature( | |||
| Tensor(node_list), | |||
| feature_types)] | |||
| def graph_info(self): | |||
| """ | |||
| @@ -212,3 +217,36 @@ class GraphData: | |||
| node_feature_type and edge_feature_type. | |||
| """ | |||
| return self._graph.graph_info() | |||
@check_gnn_random_walk
def random_walk(self, target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1):
    """
    Random walk in nodes.

    Args:
        target_nodes (list[int]): Start node list in random walk.
        meta_path (list[int]): Node type for each walk step.
        step_home_param (float): Return hyper parameter in node2vec algorithm (default=1.0).
        step_away_param (float): In-out hyper parameter in node2vec algorithm (default=1.0).
        default_node (int): Default node if no more neighbors found (default=-1).

    Returns:
        numpy.ndarray: array of nodes.

    Examples:
        >>> import mindspore.dataset as ds
        >>> data_graph = ds.GraphData('dataset_file', 2)
        >>> nodes = data_graph.random_walk([1,2], [1,2,1,2,1])

    Raises:
        TypeError: If `target_nodes` is not list or ndarray.
        TypeError: If `meta_path` is not list or ndarray.
    """
    walk_tensor = self._graph.random_walk(
        target_nodes, meta_path, step_home_param, step_away_param, default_node)
    return walk_tensor.as_array()
| @@ -1299,6 +1299,24 @@ def check_gnn_get_neg_sampled_neighbors(method): | |||
| return new_method | |||
def check_gnn_random_walk(method):
    """A wrapper that wrap a parameter checker to the GNN `random_walk` function."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        # target_nodes and meta_path are both required and must be a list or ndarray;
        # they are checked in this order.
        for required_name in ('target_nodes', 'meta_path'):
            check_gnn_list_or_ndarray(param_dict.get(required_name), required_name)

        return method(*args, **kwargs)

    return new_method
| def check_aligned_list(param, param_name, membor_type): | |||
| """Check whether the structure of each member of the list is the same.""" | |||
| @@ -27,6 +27,13 @@ | |||
| using namespace mindspore::dataset; | |||
| using namespace mindspore::dataset::gnn; | |||
// Log the elements of an integer container on one line at INFO level, preceded by a label.
// _i   - container whose begin()/end() range yields values convertible to int
// _str - label streamed before the values
// Both arguments are parenthesized so that expressions can be passed safely.
#define print_int_vec(_i, _str)                                               \
  do {                                                                        \
    std::stringstream ss;                                                     \
    std::copy((_i).begin(), (_i).end(), std::ostream_iterator<int>(ss, " ")); \
    MS_LOG(INFO) << (_str) << " " << ss.str();                                \
  } while (false)
| class MindDataTestGNNGraph : public UT::Common { | |||
| protected: | |||
| MindDataTestGNNGraph() = default; | |||
| @@ -195,3 +202,29 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { | |||
| s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos); | |||
| } | |||
| TEST_F(MindDataTestGNNGraph, TestRandomWalk) { | |||
| std::string path = "data/mindrecord/testGraphData/sns"; | |||
| Graph graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| MetaInfo meta_info; | |||
| s = graph.GetMetaInfo(&meta_info); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::shared_ptr<Tensor> nodes; | |||
| s = graph.GetAllNodes(meta_info.node_type[0], &nodes); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| std::vector<NodeIdType> node_list; | |||
| for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) { | |||
| node_list.push_back(*itr); | |||
| } | |||
| print_int_vec(node_list, "node list "); | |||
| std::vector<NodeType> meta_path(59, 1); | |||
| std::shared_ptr<Tensor> walk_path; | |||
| s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); | |||
| } | |||
| @@ -19,6 +19,7 @@ import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| DATASET_FILE = "../data/mindrecord/testGraphData/testdata" | |||
| SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" | |||
| def test_graphdata_getfullneighbor(): | |||
| @@ -172,6 +173,17 @@ def test_graphdata_generatordataset(): | |||
| assert i == 40 | |||
def test_graphdata_randomwalk():
    """
    Walk from every type-1 node of the sns graph and check the shape of the result:
    one row per start node, one column for the start node plus one per meta path entry.
    """
    graph = ds.GraphData(SOCIAL_DATA_FILE, 1)
    all_nodes = graph.get_all_nodes(1)
    node_count = len(all_nodes)
    print(node_count)
    assert node_count == 33

    step_count = 39
    walks = graph.random_walk(all_nodes, [1] * step_count)
    assert walks.shape == (33, step_count + 1)
| if __name__ == '__main__': | |||
| test_graphdata_getfullneighbor() | |||
| logger.info('test_graphdata_getfullneighbor Ended.\n') | |||
| @@ -185,3 +197,5 @@ if __name__ == '__main__': | |||
| logger.info('test_graphdata_graphinfo Ended.\n') | |||
| test_graphdata_generatordataset() | |||
| logger.info('test_graphdata_generatordataset Ended.\n') | |||
| test_graphdata_randomwalk() | |||
| logger.info('test_graphdata_randomwalk Ended.\n') | |||