
random walk v1

tags/v0.5.0-beta
Jonathan Yan 5 years ago
commit 87d2c27c7f
15 changed files with 480 additions and 20 deletions
  1. example/graph_to_mindrecord/sns/__init__.py (+0, -0)
  2. example/graph_to_mindrecord/sns/mr_api.py (+81, -0)
  3. example/graph_to_mindrecord/write_sns.sh (+10, -0)
  4. mindspore/ccsrc/dataset/api/python_bindings.cc (+10, -3)
  5. mindspore/ccsrc/dataset/engine/gnn/graph.cc (+197, -3)
  6. mindspore/ccsrc/dataset/engine/gnn/graph.h (+57, -2)
  7. mindspore/ccsrc/dataset/engine/gnn/local_node.cc (+14, -6)
  8. mindspore/ccsrc/dataset/engine/gnn/local_node.h (+2, -1)
  9. mindspore/ccsrc/dataset/engine/gnn/node.h (+2, -1)
  10. mindspore/dataset/engine/graphdata.py (+42, -4)
  11. mindspore/dataset/engine/validators.py (+18, -0)
  12. tests/ut/cpp/dataset/gnn_graph_test.cc (+33, -0)
  13. tests/ut/data/mindrecord/testGraphData/sns (BIN)
  14. tests/ut/data/mindrecord/testGraphData/sns.db (BIN)
  15. tests/ut/python/dataset/test_graphdata.py (+14, -0)

example/graph_to_mindrecord/sns/__init__.py (+0, -0)


example/graph_to_mindrecord/sns/mr_api.py (+81, -0)

@@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord GNN writer.
"""
social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
               [348, 336], [348, 337], [348, 338], [348, 340], [348, 341],
               [348, 342], [348, 343], [348, 344], [348, 345], [348, 346],
               [348, 347], [347, 351], [347, 327], [347, 329], [347, 331],
               [347, 335], [347, 341], [347, 345], [347, 346], [346, 335],
               [346, 340], [346, 339], [346, 349], [346, 353], [346, 354],
               [346, 341], [346, 345], [345, 335], [345, 336], [345, 341],
               [344, 338], [344, 342], [343, 332], [343, 338], [343, 342],
               [342, 332], [340, 349], [334, 349], [333, 349], [330, 349],
               [328, 349], [359, 349], [358, 352], [358, 349], [358, 354],
               [358, 356], [357, 350], [357, 354], [357, 356], [356, 350],
               [355, 352], [353, 350], [352, 349], [351, 349], [350, 349]]

# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (0, [], [])
edge_profile = (0, [], [])


def yield_nodes(task_id=0):
    """
    Generate node data.

    Yields:
        data (dict): Data row, represented as a dict.
    """
    print("Node task is {}".format(task_id))
    node_list = []
    for edge in social_data:
        src, dst = edge
        if src not in node_list:
            node_list.append(src)
        if dst not in node_list:
            node_list.append(dst)
    node_list.sort()
    print(node_list)
    for node_id in node_list:
        node = {'id': node_id, 'type': 1}
        yield node


def yield_edges(task_id=0):
    """
    Generate edge data.

    Yields:
        data (dict): Data row, represented as a dict.
    """
    print("Edge task is {}".format(task_id))
    line_count = 0
    for undirected_edge in social_data:
        line_count += 1
        edge = {
            'id': line_count,
            'src_id': undirected_edge[0],
            'dst_id': undirected_edge[1],
            'type': 1}
        yield edge
        line_count += 1
        edge = {
            'id': line_count,
            'src_id': undirected_edge[1],
            'dst_id': undirected_edge[0],
            'type': 1}
        yield edge
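The two generators above can also be exercised on their own, outside the MindRecord writer, as a quick sanity check. A minimal sketch (the `mr_api` import name and the standalone usage are assumptions for illustration, not part of this commit):

    # Hypothetical standalone check of the generators defined in mr_api.py.
    from mr_api import yield_nodes, yield_edges

    nodes = list(yield_nodes())
    edges = list(yield_edges())

    # social_data lists 60 undirected edges over 33 distinct node ids;
    # every undirected edge is emitted twice, once per direction.
    print(len(nodes), nodes[0])   # 33 {'id': 327, 'type': 1}
    print(len(edges), edges[0])   # 120 {'id': 1, 'src_id': 348, 'dst_id': 350, 'type': 1}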

example/graph_to_mindrecord/write_sns.sh (+10, -0)

@@ -0,0 +1,10 @@
#!/bin/bash
MINDRECORD_PATH=/tmp/sns

rm -f $MINDRECORD_PATH/*

python writer.py --mindrecord_script sns \
--mindrecord_file "$MINDRECORD_PATH/sns" \
--mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 14 \
--mindrecord_page_size_by_bit 15

mindspore/ccsrc/dataset/api/python_bindings.cc (+10, -3)

@@ -584,9 +584,16 @@ void bindGraphData(py::module *m) {
  THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
  return out;
  })
- .def("graph_info", [](gnn::Graph &g) {
-   py::dict out;
-   THROW_IF_ERROR(g.GraphInfo(&out));
+ .def("graph_info",
+      [](gnn::Graph &g) {
+        py::dict out;
+        THROW_IF_ERROR(g.GraphInfo(&out));
+        return out;
+      })
+ .def("random_walk", [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
+                        float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
+   std::shared_ptr<Tensor> out;
+   THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
  return out;
  });
}


mindspore/ccsrc/dataset/engine/gnn/graph.cc (+197, -3)

@@ -29,7 +29,7 @@ namespace dataset {
namespace gnn {

Graph::Graph(std::string dataset_file, int32_t num_workers)
-     : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) {
+     : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) {
  rnd_.seed(GetSeed());
  MS_LOG(INFO) << "num_workers:" << num_workers;
}
@@ -240,8 +240,13 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N
  return Status::OK();
}

- Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p,
-                          float q, NodeIdType default_node, std::shared_ptr<Tensor> *out) {
+ Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
+                          float step_home_param, float step_away_param, NodeIdType default_node,
+                          std::shared_ptr<Tensor> *out) {
+   RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node));
+   std::vector<std::vector<NodeIdType>> walks;
+   RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks));
+   RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({walks}, DataType(DataType::DE_INT32), out));
  return Status::OK();
}


@@ -386,6 +391,195 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
  return Status::OK();
}


Graph::RandomWalkBase::RandomWalkBase(Graph *graph)
: graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {}

Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, const NodeIdType default_node,
int32_t num_walks, int32_t num_workers) {
node_list_ = node_list;
if (meta_path.empty() || meta_path.size() > kMaxNumWalks) {
std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) +
". The size of input path is " + std::to_string(meta_path.size());
RETURN_STATUS_UNEXPECTED(err_msg);
}
meta_path_ = meta_path;
if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) {
std::string err_msg = "Failed, step_home_param and step_away_param required greater than " +
std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) +
", step_away_param: " + std::to_string(step_away_param);
RETURN_STATUS_UNEXPECTED(err_msg);
}
step_home_param_ = step_home_param;
step_away_param_ = step_away_param;
default_node_ = default_node;
num_walks_ = num_walks;
num_workers_ = num_workers;
return Status::OK();
}

Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) {
// Simulate a random walk starting from start node.
auto walk = std::vector<NodeIdType>(1, start_node);  // walk is a vector
// walk simulate
while (walk.size() - 1 < meta_path_.size()) {
// current node
auto cur_node_id = walk.back();
std::shared_ptr<Node> cur_node;
RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node));

// current neighbors
std::vector<NodeIdType> cur_neighbors;
RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true));
std::sort(cur_neighbors.begin(), cur_neighbors.end());

// break if no neighbors
if (cur_neighbors.empty()) {
break;
}

// walk by the first node, then by the previous 2 nodes
std::shared_ptr<StochasticIndex> stochastic_index;
if (walk.size() == 1) {
RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index));
} else {
NodeIdType prev_node_id = walk[walk.size() - 2];
RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index));
}
NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)];
walk.push_back(next_node_id);
}

while (walk.size() - 1 < meta_path_.size()) {
walk.push_back(default_node_);
}

*walk_path = std::move(walk);
return Status::OK();
}

Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) {
// Repeatedly simulate random walks from each node
std::vector<uint32_t> permutation(node_list_.size());
std::iota(permutation.begin(), permutation.end(), 0);
for (int32_t i = 0; i < num_walks_; i++) {
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed));
for (const auto &i_perm : permutation) {
std::vector<NodeIdType> walk;
RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk));
walks->push_back(walk);
}
}
return Status::OK();
}

Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
std::shared_ptr<StochasticIndex> *node_probability) {
// Generate alias nodes
std::shared_ptr<Node> node;
graph_->GetNodeByNodeId(node_id, &node);
std::vector<NodeIdType> neighbors;
RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true));
std::sort(neighbors.begin(), neighbors.end());
auto non_normalized_probability = std::vector<float>(neighbors.size(), 1.0);
*node_probability =
std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability)));
return Status::OK();
}

Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index,
std::shared_ptr<StochasticIndex> *edge_probability) {
// Get the alias edge setup lists for a given edge.
std::shared_ptr<Node> src_node;
graph_->GetNodeByNodeId(src, &src_node);
std::vector<NodeIdType> src_neighbors;
RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true));

std::shared_ptr<Node> dst_node;
graph_->GetNodeByNodeId(dst, &dst_node);
std::vector<NodeIdType> dst_neighbors;
RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true));

std::sort(dst_neighbors.begin(), dst_neighbors.end());
std::vector<float> non_normalized_probability;
for (const auto &dst_nbr : dst_neighbors) {
if (dst_nbr == src) {
non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight']
continue;
}
auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr);
if (it != src_neighbors.end()) {
// stay close, this node connects both src and dst
non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight']
} else {
// step far away
non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight']
}
}

*edge_probability =
std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability)));
return Status::OK();
}
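In node2vec terms, step_home_param is the return parameter p and step_away_param is the in-out parameter q: each candidate neighbor of the current node dst gets an unnormalized weight that depends only on how it relates to the previous node src. A minimal Python sketch of the same weighting rule (illustrative names, unit edge weights assumed, not part of this commit):

    def edge_weights(src, src_neighbors, dst_neighbors, step_home_param, step_away_param):
        weights = []
        for nbr in dst_neighbors:
            if nbr == src:                # step back to the previous node
                weights.append(1.0 / step_home_param)
            elif nbr in src_neighbors:    # neighbor shared with src: stay close
                weights.append(1.0)
            else:                         # neighbor not adjacent to src: step away
                weights.append(1.0 / step_away_param)
        return weights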

StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) {
uint32_t K = probability.size();
std::vector<int32_t> switch_to_large_index(K, 0);
std::vector<float> weight(K, .0);
std::vector<int32_t> smaller;
std::vector<int32_t> larger;
auto random_device = GetRandomDevice();
std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon);
float accumulate_threshold = 0.0;
for (uint32_t i = 0; i < K; i++) {
float threshold_one = distribution(random_device);
accumulate_threshold += threshold_one;
weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold;
weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i);
}

while ((!smaller.empty()) && (!larger.empty())) {
uint32_t small = smaller.back();
smaller.pop_back();
uint32_t large = larger.back();
larger.pop_back();
switch_to_large_index[small] = large;
weight[large] = weight[large] + weight[small] - 1.0;
weight[large] < 1.0 ? smaller.push_back(large) : larger.push_back(large);
}
return StochasticIndex(switch_to_large_index, weight);
}

uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) {
auto switch_to_large_index = stochastic_index.first;
auto weight = stochastic_index.second;
const uint32_t size_of_index = switch_to_large_index.size();

auto random_device = GetRandomDevice();
std::uniform_real_distribution<> distribution(0.0, 1.0);

// Generate random integer between [0, K)
uint32_t random_idx = std::floor(distribution(random_device) * size_of_index);

if (distribution(random_device) < weight[random_idx]) {
return random_idx;
}
return switch_to_large_index[random_idx];
}
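GenerateProbability and WalkToNextNode together implement alias-method sampling: the normalized probabilities are packed into a per-slot weight table plus a switch-to-large index so that each draw costs O(1). A rough Python equivalent, leaving out the small epsilon jitter the C++ version adds to the weights (names are illustrative, not from this commit):

    import random

    def generate_probability(probability):
        # Build the alias table (switch_to_large_index, weight) from normalized probabilities.
        k = len(probability)
        weight = [p * k for p in probability]
        switch_to_large_index = [0] * k
        smaller = [i for i, w in enumerate(weight) if w < 1.0]
        larger = [i for i, w in enumerate(weight) if w >= 1.0]
        while smaller and larger:
            small, large = smaller.pop(), larger.pop()
            switch_to_large_index[small] = large
            weight[large] = weight[large] + weight[small] - 1.0
            (smaller if weight[large] < 1.0 else larger).append(large)
        return switch_to_large_index, weight

    def walk_to_next_node(switch_to_large_index, weight):
        # O(1) draw: pick a slot uniformly, keep it with probability weight[slot].
        idx = random.randrange(len(switch_to_large_index))
        return idx if random.random() < weight[idx] else switch_to_large_index[idx]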

template <typename T>
std::vector<float> Graph::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) {
float sum_probability =
1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0);
if (sum_probability < kGnnEpsilon) {
sum_probability = 1.0;
}
std::vector<float> normalized_probability;
std::transform(non_normalized_probability.begin(), non_normalized_probability.end(),
std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; });
return normalized_probability;
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

mindspore/ccsrc/dataset/engine/gnn/graph.h (+57, -2)

@@ -16,12 +16,14 @@
#ifndef DATASET_ENGINE_GNN_GRAPH_H_
#define DATASET_ENGINE_GNN_GRAPH_H_

#include <algorithm>
#include <memory>
#include <string>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <utility>

#include "dataset/core/tensor.h"
#include "dataset/engine/gnn/graph_loader.h"
@@ -34,6 +36,10 @@ namespace mindspore {
namespace dataset {
namespace gnn {

const float kGnnEpsilon = 0.0001;
const uint32_t kMaxNumWalks = 80;
using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;

struct MetaInfo {
  std::vector<NodeType> node_type;
  std::vector<EdgeType> edge_type;
@@ -98,8 +104,17 @@ class Graph {
  Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
                                NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);

- Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q,
-                   NodeIdType default_node, std::shared_ptr<Tensor> *out);
+ // Node2vec random walk.
+ // @param std::vector<NodeIdType> node_list - List of nodes
+ // @param std::vector<NodeType> meta_path - node type of each step
+ // @param float step_home_param - return hyper parameter in node2vec algorithm
+ // @param float step_away_param - inout hyper parameter in node2vec algorithm
+ // @param NodeIdType default_node - default node id
+ // @param std::shared_ptr<Tensor> *out - Returned node ids in the walk path
+ // @return Status - The error code return
+ Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
+                   float step_home_param, float step_away_param, NodeIdType default_node,
+                   std::shared_ptr<Tensor> *out);

  // Get the feature of a node
  // @param std::shared_ptr<Tensor> nodes - List of nodes
@@ -130,6 +145,45 @@ class Graph {
  Status Init();

 private:
  class RandomWalkBase {
   public:
    explicit RandomWalkBase(Graph *graph);

    Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
                 float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1,
                 int32_t num_walks = 1, int32_t num_workers = 1);

    ~RandomWalkBase() = default;

    Status SimulateWalk(std::vector<std::vector<NodeIdType>> *walks);

   private:
    Status Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path);

    Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
                              std::shared_ptr<StochasticIndex> *node_probability);

    Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index,
                              std::shared_ptr<StochasticIndex> *edge_probability);

    static StochasticIndex GenerateProbability(const std::vector<float> &probability);

    static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index);

    template <typename T>
    std::vector<float> Normalize(const std::vector<T> &non_normalized_probability);

    Graph *graph_;
    std::vector<NodeIdType> node_list_;
    std::vector<NodeType> meta_path_;
    float step_home_param_;  // Return hyper parameter. Default is 1.0
    float step_away_param_;  // Inout hyper parameter. Default is 1.0
    NodeIdType default_node_;

    int32_t num_walks_;    // Number of walks per source. Default is 1
    int32_t num_workers_;  // The number of worker threads. Default is 1
  };

  // Load graph data from mindrecord file
  // @return Status - The error code return
  Status LoadNodeAndEdge();
@@ -174,6 +228,7 @@ class Graph {
  std::string dataset_file_;
  int32_t num_workers_;  // The number of worker threads
  std::mt19937 rnd_;
  RandomWalkBase random_walk_;

  std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
  std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;


mindspore/ccsrc/dataset/engine/gnn/local_node.cc (+14, -6)

@@ -39,17 +39,25 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature>
  }
}

- Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) {
+ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, bool exclude_itself) {
  std::vector<NodeIdType> neighbors;
  auto itr = neighbor_nodes_.find(neighbor_type);
  if (itr != neighbor_nodes_.end()) {
-   neighbors.resize(itr->second.size() + 1);
-   neighbors[0] = id_;
-   std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
-                  [](const std::shared_ptr<Node> node) { return node->id(); });
+   if (exclude_itself) {
+     neighbors.resize(itr->second.size());
+     std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(),
+                    [](const std::shared_ptr<Node> node) { return node->id(); });
+   } else {
+     neighbors.resize(itr->second.size() + 1);
+     neighbors[0] = id_;
+     std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
+                    [](const std::shared_ptr<Node> node) { return node->id(); });
+   }
  } else {
    MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
-   neighbors.emplace_back(id_);
+   if (!exclude_itself) {
+     neighbors.emplace_back(id_);
+   }
  }
  *out_neighbors = std::move(neighbors);
  return Status::OK();


mindspore/ccsrc/dataset/engine/gnn/local_node.h (+2, -1)

@@ -47,7 +47,8 @@ class LocalNode : public Node {
  // @param NodeType neighbor_type - type of neighbor
  // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
  // @return Status - The error code return
- Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) override;
+ Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
+                        bool exclude_itself = false) override;

  // Get the sampled neighbors of a node
  // @param NodeType neighbor_type - type of neighbor


mindspore/ccsrc/dataset/engine/gnn/node.h (+2, -1)

@@ -56,7 +56,8 @@ class Node {
  // @param NodeType neighbor_type - type of neighbor
  // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
  // @return Status - The error code return
- virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) = 0;
+ virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
+                                bool exclude_itself = false) = 0;

  // Get the sampled neighbors of a node
  // @param NodeType neighbor_type - type of neighbor


mindspore/dataset/engine/graphdata.py (+42, -4)

@@ -22,7 +22,7 @@ from mindspore._c_dataengine import Tensor


from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
    check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
-     check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature
+     check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk




class GraphData:
@@ -148,7 +148,8 @@ class GraphData:
            TypeError: If `neighbor_nums` is not list or ndarray.
            TypeError: If `neighbor_types` is not list or ndarray.
        """
-         return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array()
+         return self._graph.get_sampled_neighbors(
+             node_list, neighbor_nums, neighbor_types).as_array()

    @check_gnn_get_neg_sampled_neighbors
    def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
@@ -174,7 +175,8 @@ class GraphData:
            TypeError: If `neg_neighbor_num` is not integer.
            TypeError: If `neg_neighbor_type` is not integer.
        """
-         return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array()
+         return self._graph.get_neg_sampled_neighbors(
+             node_list, neg_neighbor_num, neg_neighbor_type).as_array()

    @check_gnn_get_node_feature
    def get_node_feature(self, node_list, feature_types):
@@ -200,7 +202,10 @@ class GraphData:
""" """
if isinstance(node_list, list): if isinstance(node_list, list):
node_list = np.array(node_list, dtype=np.int32) node_list = np.array(node_list, dtype=np.int32)
return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)]
return [
t.as_array() for t in self._graph.get_node_feature(
Tensor(node_list),
feature_types)]


def graph_info(self): def graph_info(self):
""" """
@@ -212,3 +217,36 @@ class GraphData:
            node_feature_type and edge_feature_type.
        """
        return self._graph.graph_info()

    @check_gnn_random_walk
    def random_walk(
            self,
            target_nodes,
            meta_path,
            step_home_param=1.0,
            step_away_param=1.0,
            default_node=-1):
        """
        Perform node2vec-style random walks starting from the given nodes.

        Args:
            target_nodes (list[int]): Start node list for the random walk.
            meta_path (list[int]): Node type for each walk step.
            step_home_param (float): Return hyper parameter (p) in the node2vec algorithm.
            step_away_param (float): In-out hyper parameter (q) in the node2vec algorithm.
            default_node (int): Default node id used to pad a walk when no further neighbor is found.

        Returns:
            numpy.ndarray: Array of node ids, one walk per row.

        Examples:
            >>> import mindspore.dataset as ds
            >>> data_graph = ds.GraphData('dataset_file', 2)
            >>> nodes = data_graph.random_walk([1,2], [1,2,1,2,1])

        Raises:
            TypeError: If `target_nodes` is not list or ndarray.
            TypeError: If `meta_path` is not list or ndarray.
        """
        return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
                                       default_node).as_array()

mindspore/dataset/engine/validators.py (+18, -0)

@@ -1299,6 +1299,24 @@ def check_gnn_get_neg_sampled_neighbors(method):
    return new_method




def check_gnn_random_walk(method):
    """A wrapper that wraps a parameter checker around the GNN `random_walk` function."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        # check node_list; required argument
        check_gnn_list_or_ndarray(param_dict.get("target_nodes"), 'target_nodes')

        # check meta_path; required argument
        check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path')

        return method(*args, **kwargs)

    return new_method


def check_aligned_list(param, param_name, membor_type):
    """Check whether the structure of each member of the list is the same."""




tests/ut/cpp/dataset/gnn_graph_test.cc (+33, -0)

@@ -27,6 +27,13 @@
using namespace mindspore::dataset;
using namespace mindspore::dataset::gnn;


#define print_int_vec(_i, _str) \
do { \
std::stringstream ss; \
std::copy(_i.begin(), _i.end(), std::ostream_iterator<int>(ss, " ")); \
MS_LOG(INFO) << _str << " " << ss.str(); \
} while (false)

class MindDataTestGNNGraph : public UT::Common {
 protected:
  MindDataTestGNNGraph() = default;
@@ -195,3 +202,29 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
  s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
  EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos);
}

TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
std::string path = "data/mindrecord/testGraphData/sns";
Graph graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());

MetaInfo meta_info;
s = graph.GetMetaInfo(&meta_info);
EXPECT_TRUE(s.IsOk());

std::shared_ptr<Tensor> nodes;
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
EXPECT_TRUE(s.IsOk());
std::vector<NodeIdType> node_list;
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
node_list.push_back(*itr);
}

print_int_vec(node_list, "node list ");
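// 33 start nodes and a 59-step meta path: each walk holds 60 node ids (the start node plus one id per step)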
std::vector<NodeType> meta_path(59, 1);
std::shared_ptr<Tensor> walk_path;
s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>");
}

tests/ut/data/mindrecord/testGraphData/sns (BIN)


tests/ut/data/mindrecord/testGraphData/sns.db (BIN)


tests/ut/python/dataset/test_graphdata.py (+14, -0)

@@ -19,6 +19,7 @@ import mindspore.dataset as ds
from mindspore import log as logger


DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"




def test_graphdata_getfullneighbor():
@@ -172,6 +173,17 @@ def test_graphdata_generatordataset():
    assert i == 40




def test_graphdata_randomwalk():
    g = ds.GraphData(SOCIAL_DATA_FILE, 1)
    nodes = g.get_all_nodes(1)
    print(len(nodes))
    assert len(nodes) == 33

    meta_path = [1 for _ in range(39)]
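    # each walk is the start node plus one id per meta_path step: 39 + 1 = 40 columns, one row per start node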
    walks = g.random_walk(nodes, meta_path)
    assert walks.shape == (33, 40)


if __name__ == '__main__':
    test_graphdata_getfullneighbor()
    logger.info('test_graphdata_getfullneighbor Ended.\n')
@@ -185,3 +197,5 @@ if __name__ == '__main__':
    logger.info('test_graphdata_graphinfo Ended.\n')
    test_graphdata_generatordataset()
    logger.info('test_graphdata_generatordataset Ended.\n')
    test_graphdata_randomwalk()
    logger.info('test_graphdata_randomwalk Ended.\n')
