From 87d2c27c7f99a585a89ed28bb1012d032733df52 Mon Sep 17 00:00:00 2001 From: Jonathan Yan Date: Mon, 15 Jun 2020 08:54:10 -0400 Subject: [PATCH] random walk v1 --- example/graph_to_mindrecord/sns/__init__.py | 0 example/graph_to_mindrecord/sns/mr_api.py | 81 +++++++ example/graph_to_mindrecord/write_sns.sh | 10 + .../ccsrc/dataset/api/python_bindings.cc | 13 +- mindspore/ccsrc/dataset/engine/gnn/graph.cc | 200 +++++++++++++++++- mindspore/ccsrc/dataset/engine/gnn/graph.h | 59 +++++- .../ccsrc/dataset/engine/gnn/local_node.cc | 20 +- .../ccsrc/dataset/engine/gnn/local_node.h | 3 +- mindspore/ccsrc/dataset/engine/gnn/node.h | 3 +- mindspore/dataset/engine/graphdata.py | 46 +++- mindspore/dataset/engine/validators.py | 18 ++ tests/ut/cpp/dataset/gnn_graph_test.cc | 33 +++ tests/ut/data/mindrecord/testGraphData/sns | Bin 0 -> 58572 bytes tests/ut/data/mindrecord/testGraphData/sns.db | Bin 0 -> 24576 bytes tests/ut/python/dataset/test_graphdata.py | 14 ++ 15 files changed, 480 insertions(+), 20 deletions(-) create mode 100644 example/graph_to_mindrecord/sns/__init__.py create mode 100644 example/graph_to_mindrecord/sns/mr_api.py create mode 100644 example/graph_to_mindrecord/write_sns.sh create mode 100644 tests/ut/data/mindrecord/testGraphData/sns create mode 100644 tests/ut/data/mindrecord/testGraphData/sns.db diff --git a/example/graph_to_mindrecord/sns/__init__.py b/example/graph_to_mindrecord/sns/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/graph_to_mindrecord/sns/mr_api.py b/example/graph_to_mindrecord/sns/mr_api.py new file mode 100644 index 0000000000..4e01441601 --- /dev/null +++ b/example/graph_to_mindrecord/sns/mr_api.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord GNN writer. +""" +social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335], + [348, 336], [348, 337], [348, 338], [348, 340], [348, 341], + [348, 342], [348, 343], [348, 344], [348, 345], [348, 346], + [348, 347], [347, 351], [347, 327], [347, 329], [347, 331], + [347, 335], [347, 341], [347, 345], [347, 346], [346, 335], + [346, 340], [346, 339], [346, 349], [346, 353], [346, 354], + [346, 341], [346, 345], [345, 335], [345, 336], [345, 341], + [344, 338], [344, 342], [343, 332], [343, 338], [343, 342], + [342, 332], [340, 349], [334, 349], [333, 349], [330, 349], + [328, 349], [359, 349], [358, 352], [358, 349], [358, 354], + [358, 356], [357, 350], [357, 354], [357, 356], [356, 350], + [355, 352], [353, 350], [352, 349], [351, 349], [350, 349]] + +# profile: (num_features, feature_data_types, feature_shapes) +node_profile = (0, [], []) +edge_profile = (0, [], []) + + +def yield_nodes(task_id=0): + """ + Generate node data + + Yields: + data (dict): data row which is dict. + """ + print("Node task is {}".format(task_id)) + node_list = [] + for edge in social_data: + src, dst = edge + if src not in node_list: + node_list.append(src) + if dst not in node_list: + node_list.append(dst) + node_list.sort() + print(node_list) + for node_id in node_list: + node = {'id': node_id, 'type': 1} + yield node + + +def yield_edges(task_id=0): + """ + Generate edge data + + Yields: + data (dict): data row which is dict. + """ + print("Edge task is {}".format(task_id)) + line_count = 0 + for undirected_edge in social_data: + line_count += 1 + edge = { + 'id': line_count, + 'src_id': undirected_edge[0], + 'dst_id': undirected_edge[1], + 'type': 1} + yield edge + line_count += 1 + edge = { + 'id': line_count, + 'src_id': undirected_edge[1], + 'dst_id': undirected_edge[0], + 'type': 1} + yield edge diff --git a/example/graph_to_mindrecord/write_sns.sh b/example/graph_to_mindrecord/write_sns.sh new file mode 100644 index 0000000000..f564ddc8ff --- /dev/null +++ b/example/graph_to_mindrecord/write_sns.sh @@ -0,0 +1,10 @@ +#!/bin/bash +MINDRECORD_PATH=/tmp/sns + +rm -f $MINDRECORD_PATH/* + +python writer.py --mindrecord_script sns \ +--mindrecord_file "$MINDRECORD_PATH/sns" \ +--mindrecord_partitions 1 \ +--mindrecord_header_size_by_bit 14 \ +--mindrecord_page_size_by_bit 15 diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 57fbaea027..7f3a51ffc7 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -584,9 +584,16 @@ void bindGraphData(py::module *m) { THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); return out; }) - .def("graph_info", [](gnn::Graph &g) { - py::dict out; - THROW_IF_ERROR(g.GraphInfo(&out)); + .def("graph_info", + [](gnn::Graph &g) { + py::dict out; + THROW_IF_ERROR(g.GraphInfo(&out)); + return out; + }) + .def("random_walk", [](gnn::Graph &g, std::vector node_list, std::vector meta_path, + float step_home_param, float step_away_param, gnn::NodeIdType default_node) { + std::shared_ptr out; + THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); return out; }); } diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/dataset/engine/gnn/graph.cc index 2ac3f3f5bd..791742a4fa 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.cc @@ -29,7 +29,7 @@ namespace dataset { namespace gnn { Graph::Graph(std::string dataset_file, int32_t num_workers) - : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) { + : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { rnd_.seed(GetSeed()); MS_LOG(INFO) << "num_workers:" << num_workers; } @@ -240,8 +240,13 @@ Status Graph::GetNegSampledNeighbors(const std::vector &node_list, N return Status::OK(); } -Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, float p, - float q, NodeIdType default_node, std::shared_ptr *out) { +Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out) { + RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); + std::vector> walks; + RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); + RETURN_IF_NOT_OK(CreateTensorByVector({walks}, DataType(DataType::DE_INT32), out)); return Status::OK(); } @@ -386,6 +391,195 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { return Status::OK(); } +Graph::RandomWalkBase::RandomWalkBase(Graph *graph) + : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} + +Status Graph::RandomWalkBase::Build(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, const NodeIdType default_node, + int32_t num_walks, int32_t num_workers) { + node_list_ = node_list; + if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { + std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) + + ". The size of input path is " + std::to_string(meta_path.size()); + RETURN_STATUS_UNEXPECTED(err_msg); + } + meta_path_ = meta_path; + if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { + std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + + std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) + + ", step_away_param: " + std::to_string(step_away_param); + RETURN_STATUS_UNEXPECTED(err_msg); + } + step_home_param_ = step_home_param; + step_away_param_ = step_away_param; + default_node_ = default_node; + num_walks_ = num_walks; + num_workers_ = num_workers; + return Status::OK(); +} + +Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path) { + // Simulate a random walk starting from start node. + auto walk = std::vector(1, start_node); // walk is an vector + // walk simulate + while (walk.size() - 1 < meta_path_.size()) { + // current nodE + auto cur_node_id = walk.back(); + std::shared_ptr cur_node; + RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node)); + + // current neighbors + std::vector cur_neighbors; + RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true)); + std::sort(cur_neighbors.begin(), cur_neighbors.end()); + + // break if no neighbors + if (cur_neighbors.empty()) { + break; + } + + // walk by the fist node, then by the previous 2 nodes + std::shared_ptr stochastic_index; + if (walk.size() == 1) { + RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index)); + } else { + NodeIdType prev_node_id = walk[walk.size() - 2]; + RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index)); + } + NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)]; + walk.push_back(next_node_id); + } + + while (walk.size() - 1 < meta_path_.size()) { + walk.push_back(default_node_); + } + + *walk_path = std::move(walk); + return Status::OK(); +} + +Status Graph::RandomWalkBase::SimulateWalk(std::vector> *walks) { + // Repeatedly simulate random walks from each node + std::vector permutation(node_list_.size()); + std::iota(permutation.begin(), permutation.end(), 0); + for (int32_t i = 0; i < num_walks_; i++) { + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed)); + for (const auto &i_perm : permutation) { + std::vector walk; + RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk)); + walks->push_back(walk); + } + } + return Status::OK(); +} + +Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, + std::shared_ptr *node_probability) { + // Generate alias nodes + std::shared_ptr node; + graph_->GetNodeByNodeId(node_id, &node); + std::vector neighbors; + RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true)); + std::sort(neighbors.begin(), neighbors.end()); + auto non_normalized_probability = std::vector(neighbors.size(), 1.0); + *node_probability = + std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); + return Status::OK(); +} + +Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, + std::shared_ptr *edge_probability) { + // Get the alias edge setup lists for a given edge. + std::shared_ptr src_node; + graph_->GetNodeByNodeId(src, &src_node); + std::vector src_neighbors; + RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true)); + + std::shared_ptr dst_node; + graph_->GetNodeByNodeId(dst, &dst_node); + std::vector dst_neighbors; + RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true)); + + std::sort(dst_neighbors.begin(), dst_neighbors.end()); + std::vector non_normalized_probability; + for (const auto &dst_nbr : dst_neighbors) { + if (dst_nbr == src) { + non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] + continue; + } + auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr); + if (it != src_neighbors.end()) { + // stay close, this node connect both src and dst + non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight'] + } else { + // step far away + non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] + } + } + + *edge_probability = + std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); + return Status::OK(); +} + +StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector &probability) { + uint32_t K = probability.size(); + std::vector switch_to_large_index(K, 0); + std::vector weight(K, .0); + std::vector smaller; + std::vector larger; + auto random_device = GetRandomDevice(); + std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon); + float accumulate_threshold = 0.0; + for (uint32_t i = 0; i < K; i++) { + float threshold_one = distribution(random_device); + accumulate_threshold += threshold_one; + weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold; + weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i); + } + + while ((!smaller.empty()) && (!larger.empty())) { + uint32_t small = smaller.back(); + smaller.pop_back(); + uint32_t large = larger.back(); + larger.pop_back(); + switch_to_large_index[small] = large; + weight[large] = weight[large] + weight[small] - 1.0; + weight[large] < 1.0 ? smaller.push_back(large) : larger.push_back(large); + } + return StochasticIndex(switch_to_large_index, weight); +} + +uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { + auto switch_to_large_index = stochastic_index.first; + auto weight = stochastic_index.second; + const uint32_t size_of_index = switch_to_large_index.size(); + + auto random_device = GetRandomDevice(); + std::uniform_real_distribution<> distribution(0.0, 1.0); + + // Generate random integer between [0, K) + uint32_t random_idx = std::floor(distribution(random_device) * size_of_index); + + if (distribution(random_device) < weight[random_idx]) { + return random_idx; + } + return switch_to_large_index[random_idx]; +} + +template +std::vector Graph::RandomWalkBase::Normalize(const std::vector &non_normalized_probability) { + float sum_probability = + 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); + if (sum_probability < kGnnEpsilon) { + sum_probability = 1.0; + } + std::vector normalized_probability; + std::transform(non_normalized_probability.begin(), non_normalized_probability.end(), + std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; }); + return normalized_probability; +} } // namespace gnn } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.h b/mindspore/ccsrc/dataset/engine/gnn/graph.h index 694d4eea01..79fca361ec 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.h +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.h @@ -16,12 +16,14 @@ #ifndef DATASET_ENGINE_GNN_GRAPH_H_ #define DATASET_ENGINE_GNN_GRAPH_H_ +#include #include #include #include #include #include #include +#include #include "dataset/core/tensor.h" #include "dataset/engine/gnn/graph_loader.h" @@ -34,6 +36,10 @@ namespace mindspore { namespace dataset { namespace gnn { +const float kGnnEpsilon = 0.0001; +const uint32_t kMaxNumWalks = 80; +using StochasticIndex = std::pair, std::vector>; + struct MetaInfo { std::vector node_type; std::vector edge_type; @@ -98,8 +104,17 @@ class Graph { Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, NodeType neg_neighbor_type, std::shared_ptr *out); - Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, float p, float q, - NodeIdType default_node, std::shared_ptr *out); + // Node2vec random walk. + // @param std::vector node_list - List of nodes + // @param std::vector meta_path - node type of each step + // @param float step_home_param - return hyper parameter in node2vec algorithm + // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param NodeIdType default_node - default node id + // @param std::shared_ptr *out - Returned nodes id in walk path + // @return Status - The error code return + Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out); // Get the feature of a node // @param std::shared_ptr nodes - List of nodes @@ -130,6 +145,45 @@ class Graph { Status Init(); private: + class RandomWalkBase { + public: + explicit RandomWalkBase(Graph *graph); + + Status Build(const std::vector &node_list, const std::vector &meta_path, + float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, + int32_t num_walks = 1, int32_t num_workers = 1); + + ~RandomWalkBase() = default; + + Status SimulateWalk(std::vector> *walks); + + private: + Status Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path); + + Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, + std::shared_ptr *node_probability); + + Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, + std::shared_ptr *edge_probability); + + static StochasticIndex GenerateProbability(const std::vector &probability); + + static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index); + + template + std::vector Normalize(const std::vector &non_normalized_probability); + + Graph *graph_; + std::vector node_list_; + std::vector meta_path_; + float step_home_param_; // Return hyper parameter. Default is 1.0 + float step_away_param_; // Inout hyper parameter. Default is 1.0 + NodeIdType default_node_; + + int32_t num_walks_; // Number of walks per source. Default is 10 + int32_t num_workers_; // The number of worker threads. Default is 1 + }; + // Load graph data from mindrecord file // @return Status - The error code return Status LoadNodeAndEdge(); @@ -174,6 +228,7 @@ class Graph { std::string dataset_file_; int32_t num_workers_; // The number of worker threads std::mt19937 rnd_; + RandomWalkBase random_walk_; std::unordered_map> node_type_map_; std::unordered_map> node_id_map_; diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc index e091a52faa..c829f8e8ca 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc @@ -39,17 +39,25 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr } } -Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors) { +Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, bool exclude_itself) { std::vector neighbors; auto itr = neighbor_nodes_.find(neighbor_type); if (itr != neighbor_nodes_.end()) { - neighbors.resize(itr->second.size() + 1); - neighbors[0] = id_; - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, - [](const std::shared_ptr node) { return node->id(); }); + if (exclude_itself) { + neighbors.resize(itr->second.size()); + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), + [](const std::shared_ptr node) { return node->id(); }); + } else { + neighbors.resize(itr->second.size() + 1); + neighbors[0] = id_; + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, + [](const std::shared_ptr node) { return node->id(); }); + } } else { MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; - neighbors.emplace_back(id_); + if (!exclude_itself) { + neighbors.emplace_back(id_); + } } *out_neighbors = std::move(neighbors); return Status::OK(); diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/dataset/engine/gnn/local_node.h index b9b007c420..bc069d073f 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.h +++ b/mindspore/ccsrc/dataset/engine/gnn/local_node.h @@ -47,7 +47,8 @@ class LocalNode : public Node { // @param NodeType neighbor_type - type of neighbor // @param std::vector *out_neighbors - Returned neighbors id // @return Status - The error code return - Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors) override; + Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, + bool exclude_itself = false) override; // Get the sampled neighbors of a node // @param NodeType neighbor_type - type of neighbor diff --git a/mindspore/ccsrc/dataset/engine/gnn/node.h b/mindspore/ccsrc/dataset/engine/gnn/node.h index f0136e92d7..282f856797 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/node.h +++ b/mindspore/ccsrc/dataset/engine/gnn/node.h @@ -56,7 +56,8 @@ class Node { // @param NodeType neighbor_type - type of neighbor // @param std::vector *out_neighbors - Returned neighbors id // @return Status - The error code return - virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors) = 0; + virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, + bool exclude_itself = false) = 0; // Get the sampled neighbors of a node // @param NodeType neighbor_type - type of neighbor diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index e6ff22dd0d..838dd53f0a 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -22,7 +22,7 @@ from mindspore._c_dataengine import Tensor from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ - check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature + check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk class GraphData: @@ -148,7 +148,8 @@ class GraphData: TypeError: If `neighbor_nums` is not list or ndarray. TypeError: If `neighbor_types` is not list or ndarray. """ - return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array() + return self._graph.get_sampled_neighbors( + node_list, neighbor_nums, neighbor_types).as_array() @check_gnn_get_neg_sampled_neighbors def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): @@ -174,7 +175,8 @@ class GraphData: TypeError: If `neg_neighbor_num` is not integer. TypeError: If `neg_neighbor_type` is not integer. """ - return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array() + return self._graph.get_neg_sampled_neighbors( + node_list, neg_neighbor_num, neg_neighbor_type).as_array() @check_gnn_get_node_feature def get_node_feature(self, node_list, feature_types): @@ -200,7 +202,10 @@ class GraphData: """ if isinstance(node_list, list): node_list = np.array(node_list, dtype=np.int32) - return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)] + return [ + t.as_array() for t in self._graph.get_node_feature( + Tensor(node_list), + feature_types)] def graph_info(self): """ @@ -212,3 +217,36 @@ class GraphData: node_feature_type and edge_feature_type. """ return self._graph.graph_info() + + @check_gnn_random_walk + def random_walk( + self, + target_nodes, + meta_path, + step_home_param=1.0, + step_away_param=1.0, + default_node=-1): + """ + Random walk in nodes. + + Args: + target_nodes (list[int]): Start node list in random walk + meta_path (list[int]): node type for each walk step + step_home_param (float): return hyper parameter in node2vec algorithm + step_away_param (float): inout hyper parameter in node2vec algorithm + default_node (int): default node if no more neighbors found + + Returns: + numpy.ndarray: array of nodes. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> nodes = data_graph.random_walk([1,2], [1,2,1,2,1]) + + Raises: + TypeError: If `target_nodes` is not list or ndarray. + TypeError: If `meta_path` is not list or ndarray. + """ + return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param, + default_node).as_array() diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index ff434c718e..54dc6afd6d 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1299,6 +1299,24 @@ def check_gnn_get_neg_sampled_neighbors(method): return new_method +def check_gnn_random_walk(method): + """A wrapper that wrap a parameter checker to the GNN `random_walk` function.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check node_list; required argument + check_gnn_list_or_ndarray(param_dict.get("target_nodes"), 'target_nodes') + + # check meta_path; required argument + check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path') + + return method(*args, **kwargs) + + return new_method + + def check_aligned_list(param, param_name, membor_type): """Check whether the structure of each member of the list is the same.""" diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index 7c644a3ae7..ce2aca4ffd 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -27,6 +27,13 @@ using namespace mindspore::dataset; using namespace mindspore::dataset::gnn; +#define print_int_vec(_i, _str) \ + do { \ + std::stringstream ss; \ + std::copy(_i.begin(), _i.end(), std::ostream_iterator(ss, " ")); \ + MS_LOG(INFO) << _str << " " << ss.str(); \ + } while (false) + class MindDataTestGNNGraph : public UT::Common { protected: MindDataTestGNNGraph() = default; @@ -195,3 +202,29 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors); EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos); } + +TEST_F(MindDataTestGNNGraph, TestRandomWalk) { + std::string path = "data/mindrecord/testGraphData/sns"; + Graph graph(path, 1); + Status s = graph.Init(); + EXPECT_TRUE(s.IsOk()); + + MetaInfo meta_info; + s = graph.GetMetaInfo(&meta_info); + EXPECT_TRUE(s.IsOk()); + + std::shared_ptr nodes; + s = graph.GetAllNodes(meta_info.node_type[0], &nodes); + EXPECT_TRUE(s.IsOk()); + std::vector node_list; + for (auto itr = nodes->begin(); itr != nodes->end(); ++itr) { + node_list.push_back(*itr); + } + + print_int_vec(node_list, "node list "); + std::vector meta_path(59, 1); + std::shared_ptr walk_path; + s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); +} \ No newline at end of file diff --git a/tests/ut/data/mindrecord/testGraphData/sns b/tests/ut/data/mindrecord/testGraphData/sns new file mode 100644 index 0000000000000000000000000000000000000000..37a2c3dd30b7344083968973520f83702ec6266b GIT binary patch literal 58572 zcmeI1>2lLn9L1y3(tY1sx~Y3>Xm-F|lhOi7AP@opLNKx=+iHWIN(xIe`Pc{POh3TG z=w12-^vV|Y5l!G|?sSHko_xsIl72c@KV98><^SJ%MtXl+ZaeV&AgaaT?}1w#o1Qo{ zS#iUTAKa)l!(iQy-RkNVZg8uvmn2bG-%J9x;>L}3aMi1YzFQsLuH1WMGmPRS+xega zJB+vkZ*Kgnfom0a!)uufHP;lFyr_G<*YvTeiT~mVE}5OO6?Hc^%!XWE=BDCyo6R^#GEJrxaog)(MYbnS zyeJv8rKb@aXO8HlWRAxQO*=d}es^cyw56Gq|FUJdzr43hj31mfYWiKE=d`}wtq=Q+ z8~8>}Gw_nl$oR=Rl`+=o`uA=$r}cxlVboqVM@6PD7&)!T+i2JFlMhq2p!-#~-2A#@ z(<3`Ace}Dvra8>`&Tif6{xOr4Ghrv07|)PFFCCuG?3*Oa$!d)Qw*-^hvMAnR9d+=*RtG57YQB#e`=VUFl*uCG_zYe5u;-HuT^abVQl z?mzYVF9R|l12P~3G9UvoAOkWW12P~3G9UvoAOkWW12P~3G9UvoAOkWW12P~3G9Uvo zAOkWW12P~3G9UvoAOkWW12P~3G9UvoAOkWW12P~3G9UvoAOkY+Fd5iaY|LKs)fu^y zZ_K~xn1LkuQQsID(Q|P}V}iy6V?g7f#=}7`(3qexL1Ti}&%LKWYb33a2E(D&33{g< z&J0>73?2-vpS6D0`q}*A_!n z+i}q7c6{??}3nKX2D*uU_A@AZoxJzSl@yT zEZC+6+qw@sy~Erc4CZcUNT$xE!Q7oXm;1B#{pm-Lz6|swhTelTW*=qfN0B~SHv2K8 zG5a`^y%*`@WwReg8nfq_>?e?(FPr@&(wM!-WIu)UV%hAck;d#(O!hNKpDLUEEYg^L zn#q0+>C^G2JDVzN!(wKdo$$ksz^JTN&MjEp(FxmT%zEC#%9i%b)B9r|t z(ih8SzlSttuQJ*1BfVNS`-9!+8k6lJU4v|@jZ#nCwGHua(U{j5KC@Om-D%uWa_$NMm-L$^HiEdfDu6k;ZIlZ21mp99yXS z{XNo{O^q!-AdO=Sl|8c?O+DYUNaOQOwfskTf32Drx?)T3~b@0t%fqA~6Zjm|*@O)Em zHBh-IW>e0gc2RUrWkdC%m`yo{`bE(>l?@e)B277m8b;AM)pDp}qz=x&)I0`tjMTw- zjGFzSl94(mp|?=kP|GN0pJaL}R5McNMA>YpXB4wfFxgPiNS(Q|*-+CcX3sI%P}NAC zV`a0Uu2IZB#$-cfBXwG3v!S+8%x*E+P~AwKCS+e>XsBZ||(0t5&wg8)fL=(Xs%dN1g;M@S$pLM$L-*&r<2 zGM2Gy3oNJ9LsRUbDdP|i;}Q>(5RdCpmnIqyNr+1mjY}OSr5-h5TqV|c`V+`!{Wz2S z)0xhs_u4Ccj_!x=-u>zBJ?FsLTUQh6jgdh2zJa|RvA|l+&LPC*2Lc?&nc!i8#}}Ic zVE$)cZ0HO4Bi~Kj$}PW;)mhFcM>+X>>OaCaebXzTS3s|TUID!VdIj_f=oQc_pjSYz zz@KdeDvXxI^mNn}>*(wwb)^NhMQxP@+l%95+FDy&5H1dc3kpNUf%wle=e0n%xG5Z{ ztb#{ZgmqQFEeslNQ2fJfEy#tKBISYFq>7Rc_IE4GLf$G|_?FF^Xfr{eh%m=@` z`%jhq;BX#1?ftK7?|;o%x^#uXl30*}xLz$v1`q7-gIRrB$HCaX`|Y-cQ=x730vO@V zW}AP*nc;!d&*sshex{D8d(?7uscKVxt-Pnari>_EN}-ajaPqI@AIN_x56P|a202NZ zm3}V0C7qIDQk}F$a)>kHPs9u2s2COVMV~M$+z>7bV?vKmBn0?5{w9BkALsk{Qa*#{ zY`1KeZ4yK5HH_H<@$H zcGF$cHPboMF;m2pYjPWBjMt6ljU&c(W1i7#xM!F$Tri9pqK15f56z+*=pq_JJ*Wr; zxViZk$J(&wLcGCXv%y2PA)5jI3q`1s-$L<3djlJJq%}e((a`5${kfA<){HQ;oS_UHz_M=xALxXhm@fv)PQ># zPLKDKq%7sE@-8VwwUE=paCUhwdOuNyl(VD+a%%A|hSTjG@Lp7!lp(SW)uU$I&2We} z$2*|pC{3goHKTe=7)~_tQ*Vx9S8_-Zs)n2>!|6)AocO7HSFw|=kW-Dj7*1#6k;KdL zHTf?d`QWtRc81fIa5dp`IU*k; zTTlx;R~y6GnQ%Pes+=oF$Y!Wd3*O0ab|e%f9GBg4F4+V*^>_!vi6mGOiliCYP4ZAD zePBx$p$Z26X&7P3wgft?ZCl5nT2X1CK zO&-M)lG>#al8bggP7}jvbWeE{DNkxA4?)fj+{kbm+%LMPB(Ic5)}uPeX<#_@?y&ns z@t))*>ma8N*E5{3JHZ_mr^I_C2ZbRg%y8;lH{A(v&6y%=At#LM7*4J0TdtdMRk}dd z!1>dHYZ*?BE8_YVT(?HaYRIX=H4LZPmF9}Tl`TrL!D+?S45!LD?Mf4UVm?`gTH(2> z7*3`0+s`j2?F=X6T<2#N!aFre-%b-53xSZjXIr<#mf$Mk=SqeE-xQyYHI#xRR zgn&>)GEoKOlro$W`;22HKPLpp63D5*B@Abq{Z0D}^cHhuF)D+cZ49T_9<#p*{mD(T z2y)7BF~cdcueZmb$GJo@PzmG|F`TXHtbIN7N#i64IVE^2!zonXR%fBt>LVDnK~5pV zDNu*hx1nDvCFzjUh6@-@zPec*;yFHp1i*>le1@|{<gefB%hzg{M(kkf)UF`PVQn=%4D@C5O}c^JWY3}>TaQ?@}L-A_^= zrx|Z#I2+{m6&v*8A(9MzN(66UI1kHX^7}TEEt4ccP7{8Z;pEDd@|bnnW+GnbD&X~YjPob}QNvK>YfS4aZ%rxCoK;jEJ;qz_=cF-bg-(}348oE)h^ znt%~WjJV%l_mouDY z{BMLz^Bs#s1i1c0@G^$8l)u3L2F8JR2oE_`cqzll%w?uD)Z%wB1Sk9+ zVldUYprK~JlOZ^abEaIF35{x~$?sqYcH>p!988p=8fx_08G;H`ylU_n@-@`pR~dq0 zEH#dxS%XhQ^?rpR$VSOn3bU|T4Tb$OLy!zpMhPZqH#Ah|ml%R*m@rJi9PXlqYW*TZ z5DZO*37F!IX{g38Fa#c6K@-dXdo)z-=NW{d!xj7WjP^I6> z5G?Tg=eV2PoQAghEeyeovQQ6qiMy$xkl)M@Oxz60;>Nj48mdS$EeL3SKDzMcHNu~s z>Ep&VRPHmve%NP2D8An?$XzP!(@y80vaE%k5Jv+61JgnCqsslDn>wN5QnH>qpXC8|$#s21fj zWk&gT{Ps&H+eR5Q8hTeZG%&Jz%SWc8B8A+c?{~=9FKb78>ekfg%{ziIL zdPy3U9+L*7T~b7(rxhLX{VeWXmaf}-l?~hC2A7PGi$Akq&j~}J@*y9w39;5i^5XB>pQhekH z#lc4?9v-B4=rF~DhbYDlQali&IBO0txL2dtx0m9cK8m~dQ0(1J zv8R{ft{#fryC{-wicvzbD@w7mi(*G7#r6)0ZS53ywo%-%lVW5C#nuSLmR5?*EfkxY zDK<7yY-prd-#{^3Pq8jcv9^w4O)bUh8j4lb6f3JJZm*;m+D@?|M6tYrVp%!G(lUxA zr4+Z7P%Pd?v8b5h)*_09TPYS4Qp_))xFw(B<}DOAZKjyFiQ>jQiW@dke0T%J+=nSX zluL2_LloDor(1@f58#iWVzHvxTC` zOwnkfXfX2dmjLtI{NEuw*mxQGZKVF5ga7qSuYg_wy#jg#^a|(|&?}%+ApZ_5a^|U+C-q_4WT>TE?ue|G)nauDrhf|3Mujef|G@A^88p_5XPp z8Y-Ee|1al%#HqK`%j$&MuZGl2)uc=-S71edOsQ6~6q|fozAB%Q56d<3a@j2XO8P#m z$KNZJNa^Bd;!nlziqD8$;wCW>R@VQM@S1Q`XcpEAGXHCc)i=EYdIj_f=oQc_pjSYz zfL?+B?Fs~klU#;x+{K8S%s82G-P7(FA%k#Bmz}|3FUxv4qdMcHC(qMP4#Nk`cLawM zS=Q+cXGXQh>&YXBAge7noWQb91#e_H-S<3Rau9Bo0a+fFH4!`=yy2d5-y<=|Y7Y*( zS=LLzrr>G!1@{y=fZFG?TrBIwU~;g@J?g$d2H@yBg2PUh^#Z;XOm;`zqhvqaBnK+t zU|Hk%ReZ~x?~amwc*d^au$^T+kK6I9Zl61!>_c7iN3XK1lNiG-hGyM90-qyC(crMc zvYt!7jj`*7YnJRq(fQ+&S=Lzk>*=>$7hN|j5C?GR)MEbE!T-SlOyB3BRTg{;otFwe3^18)TGx&p2u(gUv%vTQ8t z=|F$r4dAqc$Su$g6z z_}}v1bB;SNktn=I$TG34C;Wr{x14>>anc1@@EVOQ>$rb|f6!U#>?56Uu0oc9Weumz z`8PN-oTa1#vfx}rEbCa>57OqGoHK*8!}-;Xhq*6*7<}R0V@K0KbKG)rqz$6l@ll5J zcv@lFF~?=cEdp;?v3R9G3|gaIdDv7|u}Ydub*|zhi<#Ag2@$F`P$J zM^oQ(gdF{(6_rBHqYURrYI*9YBhwKga91c)j*l>$M^dHKa)-&0Nt#hPrA57Mw~-pSZ%qjAXE^=IzfAF|SJm638k|Pl&v5o7pH2RydPcoU bs^C644R{~J(ULop&#Hs!8Bz)Nl4<-OI=H}F literal 0 HcmV?d00001 diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 9b4ff66ac1..4083336623 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -19,6 +19,7 @@ import mindspore.dataset as ds from mindspore import log as logger DATASET_FILE = "../data/mindrecord/testGraphData/testdata" +SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" def test_graphdata_getfullneighbor(): @@ -172,6 +173,17 @@ def test_graphdata_generatordataset(): assert i == 40 +def test_graphdata_randomwalk(): + g = ds.GraphData(SOCIAL_DATA_FILE, 1) + nodes = g.get_all_nodes(1) + print(len(nodes)) + assert len(nodes) == 33 + + meta_path = [1 for _ in range(39)] + walks = g.random_walk(nodes, meta_path) + assert walks.shape == (33, 40) + + if __name__ == '__main__': test_graphdata_getfullneighbor() logger.info('test_graphdata_getfullneighbor Ended.\n') @@ -185,3 +197,5 @@ if __name__ == '__main__': logger.info('test_graphdata_graphinfo Ended.\n') test_graphdata_generatordataset() logger.info('test_graphdata_generatordataset Ended.\n') + test_graphdata_randomwalk() + logger.info('test_graphdata_randomwalk Ended.\n')