| @@ -65,9 +65,9 @@ PYBIND_REGISTER( | |||
| }) | |||
| .def("get_sampled_neighbors", | |||
| [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums, | |||
| std::vector<gnn::NodeType> neighbor_types) { | |||
| std::vector<gnn::NodeType> neighbor_types, SamplingStrategy strategy) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); | |||
| THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_neg_sampled_neighbors", | |||
| @@ -114,8 +114,15 @@ PYBIND_REGISTER( | |||
| return out; | |||
| })) | |||
| .def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); }) | |||
| .def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); }); | |||
| .def("is_stopped", [](gnn::GraphDataServer &g) { return g.IsStopped(); }); | |||
| })); | |||
| PYBIND_REGISTER(SamplingStrategy, 0, ([](const py::module *m) { | |||
| (void)py::enum_<SamplingStrategy>(*m, "SamplingStrategy", py::arithmetic()) | |||
| .value("DE_SAMPLING_RANDOM", SamplingStrategy::kRandom) | |||
| .value("DE_SAMPLING_EDGE_WEIGHT", SamplingStrategy::kEdgeWeight) | |||
| .export_values(); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -71,6 +71,9 @@ enum class NormalizeForm { | |||
| kNfkd, | |||
| }; | |||
// Possible values for SamplingStrategy
enum class SamplingStrategy {
  kRandom = 0,      // neighbors are picked at random
  kEdgeWeight = 1,  // neighbors are picked with the edge weight as probability
};
// Convenience helper for 32-bit bitmasks.
// @param uint32_t bits - value to inspect
// @param uint32_t bitMask - bits that must all be set
// @return bool - true when every bit of bitMask is also set in bits
inline bool BitTest(uint32_t bits, uint32_t bitMask) {
  const uint32_t masked = bits & bitMask;
  return masked == bitMask;
}
| @@ -35,10 +35,11 @@ class Edge { | |||
| // Constructor | |||
| // @param EdgeIdType id - edge id | |||
| // @param EdgeType type - edge type | |||
| // @param WeightType weight - edge weight | |||
| // @param std::shared_ptr<Node> src_node - source node | |||
| // @param std::shared_ptr<Node> dst_node - destination node | |||
| Edge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node) | |||
| : id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {} | |||
| Edge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node) | |||
| : id_(id), type_(type), weight_(weight), src_node_(src_node), dst_node_(dst_node) {} | |||
| virtual ~Edge() = default; | |||
| @@ -48,6 +49,9 @@ class Edge { | |||
| // @return EdgeType - Returned edge type | |||
| EdgeType type() const { return type_; } | |||
| // @return WeightType - Returned edge weight | |||
| WeightType weight() const { return weight_; } | |||
| // Get the feature of an edge | |||
| // @param FeatureType feature_type - type of feature | |||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | |||
| @@ -77,6 +81,7 @@ class Edge { | |||
| protected: | |||
| EdgeIdType id_; | |||
| EdgeType type_; | |||
| WeightType weight_; | |||
| std::shared_ptr<Node> src_node_; | |||
| std::shared_ptr<Node> dst_node_; | |||
| }; | |||
| @@ -71,6 +71,7 @@ message GnnGraphDataRequestPb { | |||
| repeated int32 number = 4; // samples number | |||
| TensorPb id_tensor = 5; // input ids ,node id or edge id | |||
| GnnRandomWalkPb random_walk = 6; | |||
| int32 strategy = 7; | |||
| } | |||
| message GnnGraphDataResponsePb { | |||
| @@ -76,11 +76,13 @@ class GraphData { | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||
| // @param SamplingStrategy strategy - Sampling strategy | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||
| // @return Status The status code returned | |||
| virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0; | |||
| const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy, | |||
| std::shared_ptr<Tensor> *out) = 0; | |||
| // Get negative sampled neighbors. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| @@ -95,7 +97,7 @@ class GraphData { | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param std::vector<NodeType> meta_path - node type of each step | |||
| // @param float step_home_param - return hyper parameter in node2vec algorithm | |||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | |||
| // @param float step_away_param - in out hyper parameter in node2vec algorithm | |||
| // @param NodeIdType default_node - default node id | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | |||
| // @return Status The status code returned | |||
| @@ -137,7 +137,8 @@ Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list | |||
| Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||
| const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy, | |||
| std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| @@ -151,6 +152,7 @@ Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_ | |||
| for (const auto &type : neighbor_types) { | |||
| request.add_type(static_cast<google::protobuf::int32>(type)); | |||
| } | |||
| request.set_strategy(static_cast<google::protobuf::int32>(strategy)); | |||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||
| #endif | |||
| return Status::OK(); | |||
| @@ -86,10 +86,12 @@ class GraphDataClient : public GraphData { | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||
| // @param SamplingStrategy strategy - Sampling strategy | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||
| // @return Status The status code returned | |||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | |||
| const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| // Get negative sampled neighbors. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| @@ -104,7 +106,7 @@ class GraphDataClient : public GraphData { | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param std::vector<NodeType> meta_path - node type of each step | |||
| // @param float step_home_param - return hyper parameter in node2vec algorithm | |||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | |||
| // @param float step_away_param - in out hyper parameter in node2vec algorithm | |||
| // @param NodeIdType default_node - default node id | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | |||
| // @return Status The status code returned | |||
| @@ -171,7 +171,8 @@ Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) { | |||
| Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||
| const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy, | |||
| std::shared_ptr<Tensor> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), | |||
| "The sizes of neighbor_nums and neighbor_types are inconsistent."); | |||
| @@ -199,7 +200,7 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li | |||
| std::shared_ptr<Node> node; | |||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); | |||
| std::vector<NodeIdType> out; | |||
| RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); | |||
| RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], strategy, &out)); | |||
| neighbors.insert(neighbors.end(), out.begin(), out.end()); | |||
| } | |||
| } | |||
| @@ -80,10 +80,12 @@ class GraphDataImpl : public GraphData { | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||
| // @param SamplingStrategy strategy - Sampling strategy | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||
| // @return Status The status code returned | |||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | |||
| const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| // Get negative sampled neighbors. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| @@ -98,7 +100,7 @@ class GraphDataImpl : public GraphData { | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param std::vector<NodeType> meta_path - node type of each step | |||
| // @param float step_home_param - return hyper parameter in node2vec algorithm | |||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | |||
| // @param float step_away_param - in out hyper parameter in node2vec algorithm | |||
| // @param NodeIdType default_node - default node id | |||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | |||
| // @return Status The status code returned | |||
| @@ -194,7 +196,7 @@ class GraphDataImpl : public GraphData { | |||
| std::vector<NodeIdType> node_list_; | |||
| std::vector<NodeType> meta_path_; | |||
| float step_home_param_; // Return hyper parameter. Default is 1.0 | |||
| float step_away_param_; // Inout hyper parameter. Default is 1.0 | |||
| float step_away_param_; // In out hyper parameter. Default is 1.0 | |||
| NodeIdType default_node_; | |||
| int32_t num_walks_; // Number of walks per source. Default is 1 | |||
| @@ -50,7 +50,7 @@ class GraphDataServer { | |||
| enum ServerState state() { return state_; } | |||
| bool IsStoped() { | |||
| bool IsStopped() { | |||
| if (state_ == kGdsStopped) { | |||
| return true; | |||
| } else { | |||
| @@ -78,7 +78,7 @@ grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context, | |||
| } | |||
| break; | |||
| case GraphDataServer::kGdsStopped: | |||
| response->set_error_msg("Stoped"); | |||
| response->set_error_msg("Stopped"); | |||
| break; | |||
| } | |||
| } else { | |||
| @@ -222,8 +222,9 @@ Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *re | |||
| neighbor_types.resize(request->type().size()); | |||
| std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(), | |||
| [](const google::protobuf::int32 type) { return static_cast<NodeType>(type); }); | |||
| SamplingStrategy strategy = static_cast<SamplingStrategy>(request->strategy()); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor)); | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| return Status::OK(); | |||
| @@ -39,7 +39,9 @@ GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int | |||
| row_id_(0), | |||
| shard_reader_(nullptr), | |||
| graph_feature_parser_(nullptr), | |||
| keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} | |||
| required_key_( | |||
| {"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}), | |||
| optional_key_({{"weight", false}}) {} | |||
| Status GraphLoader::GetNodesAndEdges() { | |||
| NodeIdMap *n_id_map = &graph_impl_->node_id_map_; | |||
| @@ -62,7 +64,7 @@ Status GraphLoader::GetNodesAndEdges() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); | |||
| RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); | |||
| RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); | |||
| RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight())); | |||
| e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ | |||
| graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); | |||
| dq.pop_front(); | |||
| @@ -95,12 +97,18 @@ Status GraphLoader::InitAndLoad() { | |||
| graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); | |||
| mindrecord::json schema = graph_impl_->data_schema_["schema"]; | |||
| for (const std::string &key : keys_) { | |||
| for (const std::string &key : required_key_) { | |||
| if (schema.find(key) == schema.end()) { | |||
| RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); | |||
| } | |||
| } | |||
| for (auto op_key : optional_key_) { | |||
| if (schema.find(op_key.first) != schema.end()) { | |||
| optional_key_[op_key.first] = true; | |||
| } | |||
| } | |||
| if (graph_impl_->server_mode_) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| int64_t total_blob_size = 0; | |||
| @@ -128,7 +136,11 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec | |||
| DefaultNodeFeatureMap *default_feature) { | |||
| NodeIdType node_id = col_jsn["first_id"]; | |||
| NodeType node_type = static_cast<NodeType>(col_jsn["type"]); | |||
| (*node) = std::make_shared<LocalNode>(node_id, node_type); | |||
| WeightType weight = 1; | |||
| if (optional_key_["weight"]) { | |||
| weight = col_jsn["weight"]; | |||
| } | |||
| (*node) = std::make_shared<LocalNode>(node_id, node_type, weight); | |||
| std::vector<int32_t> indices; | |||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices)); | |||
| if (graph_impl_->server_mode_) { | |||
| @@ -174,9 +186,13 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec | |||
| EdgeIdType edge_id = col_jsn["first_id"]; | |||
| EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]); | |||
| NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; | |||
| std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1); | |||
| std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1); | |||
| (*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst); | |||
| WeightType edge_weight = 1; | |||
| if (optional_key_["weight"]) { | |||
| edge_weight = col_jsn["weight"]; | |||
| } | |||
| std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1, 1); | |||
| std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1, 1); | |||
| (*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, edge_weight, src, dst); | |||
| std::vector<int32_t> indices; | |||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices)); | |||
| if (graph_impl_->server_mode_) { | |||
| @@ -110,7 +110,8 @@ class GraphLoader { | |||
| std::vector<EdgeFeatureMap> e_feature_maps_; | |||
| std::vector<DefaultNodeFeatureMap> default_node_feature_maps_; | |||
| std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_; | |||
| const std::vector<std::string> keys_; | |||
| const std::vector<std::string> required_key_; | |||
| std::unordered_map<std::string, bool> optional_key_; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| @@ -21,8 +21,9 @@ namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node) | |||
| : Edge(id, type, src_node, dst_node) {} | |||
| LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node, | |||
| std::shared_ptr<Node> dst_node) | |||
| : Edge(id, type, weight, src_node, dst_node) {} | |||
| Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||
| auto itr = features_.find(feature_type); | |||
| @@ -34,9 +34,11 @@ class LocalEdge : public Edge { | |||
| // Constructor | |||
| // @param EdgeIdType id - edge id | |||
| // @param EdgeType type - edge type | |||
| // @param WeightType weight - edge weight | |||
| // @param std::shared_ptr<Node> src_node - source node | |||
| // @param std::shared_ptr<Node> dst_node - destination node | |||
| LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node); | |||
| LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node, | |||
| std::shared_ptr<Node> dst_node); | |||
| ~LocalEdge() = default; | |||
| @@ -16,6 +16,7 @@ | |||
| #include "minddata/dataset/engine/gnn/local_node.h" | |||
| #include <algorithm> | |||
| #include <random> | |||
| #include <string> | |||
| #include <utility> | |||
| @@ -26,7 +27,10 @@ namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } | |||
| LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight) | |||
| : Node(id, type, weight), rnd_(GetRandomDevice()) { | |||
| rnd_.seed(GetSeed()); | |||
| } | |||
| Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||
| auto itr = features_.find(feature_type); | |||
| @@ -44,13 +48,13 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType | |||
| auto itr = neighbor_nodes_.find(neighbor_type); | |||
| if (itr != neighbor_nodes_.end()) { | |||
| if (exclude_itself) { | |||
| neighbors.resize(itr->second.size()); | |||
| std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), | |||
| neighbors.resize(itr->second.first.size()); | |||
| std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin(), | |||
| [](const std::shared_ptr<Node> node) { return node->id(); }); | |||
| } else { | |||
| neighbors.resize(itr->second.size() + 1); | |||
| neighbors.resize(itr->second.first.size() + 1); | |||
| neighbors[0] = id_; | |||
| std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, | |||
| std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin() + 1, | |||
| [](const std::shared_ptr<Node> node) { return node->id(); }); | |||
| } | |||
| } else { | |||
| @@ -63,8 +67,8 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType | |||
| return Status::OK(); | |||
| } | |||
| Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, | |||
| std::vector<NodeIdType> *out) { | |||
| Status LocalNode::GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, | |||
| std::vector<NodeIdType> *out) { | |||
| std::vector<NodeIdType> shuffled_id(neighbors.size()); | |||
| std::iota(shuffled_id.begin(), shuffled_id.end(), 0); | |||
| std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); | |||
| @@ -75,14 +79,33 @@ Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> & | |||
| return Status::OK(); | |||
| } | |||
| Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||
| Status LocalNode::GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, | |||
| const std::vector<WeightType> &weights, int32_t samples_num, | |||
| std::vector<NodeIdType> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(neighbors.size() == weights.size(), | |||
| "The number of neighbors does not match the weight."); | |||
| std::discrete_distribution<NodeIdType> discrete_dist(weights.begin(), weights.end()); | |||
| for (int32_t i = 0; i < samples_num; ++i) { | |||
| NodeIdType index = discrete_dist(rnd_); | |||
| out->emplace_back(neighbors[index]->id()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, | |||
| std::vector<NodeIdType> *out_neighbors) { | |||
| std::vector<NodeIdType> neighbors; | |||
| neighbors.reserve(samples_num); | |||
| auto itr = neighbor_nodes_.find(neighbor_type); | |||
| if (itr != neighbor_nodes_.end()) { | |||
| while (neighbors.size() < samples_num) { | |||
| RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); | |||
| if (strategy == SamplingStrategy::kRandom) { | |||
| while (neighbors.size() < samples_num) { | |||
| RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors)); | |||
| } | |||
| } else if (strategy == SamplingStrategy::kEdgeWeight) { | |||
| RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors)); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Invalid strategy"); | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; | |||
| @@ -95,12 +118,15 @@ Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_nu | |||
| return Status::OK(); | |||
| } | |||
| Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node) { | |||
| Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) { | |||
| auto itr = neighbor_nodes_.find(node->type()); | |||
| if (itr != neighbor_nodes_.end()) { | |||
| itr->second.push_back(node); | |||
| itr->second.first.push_back(node); | |||
| itr->second.second.push_back(weight); | |||
| } else { | |||
| neighbor_nodes_[node->type()] = {node}; | |||
| std::vector<std::shared_ptr<Node>> nodes = {node}; | |||
| std::vector<WeightType> weights = {weight}; | |||
| neighbor_nodes_[node->type()] = std::make_pair(std::move(nodes), std::move(weights)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -18,6 +18,7 @@ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/gnn/node.h" | |||
| @@ -33,7 +34,7 @@ class LocalNode : public Node { | |||
| // Constructor | |||
| // @param NodeIdType id - node id | |||
| // @param NodeType type - node type | |||
| // @param WeightType weight - node weight | |||
| LocalNode(NodeIdType id, NodeType type); | |||
| LocalNode(NodeIdType id, NodeType type, WeightType weight); | |||
| ~LocalNode() = default; | |||
| @@ -53,15 +54,16 @@ class LocalNode : public Node { | |||
| // Get the sampled neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param int32_t samples_num - Number of neighbors to be acquired | |||
| // @param SamplingStrategy strategy - Sampling strategy | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status The status code returned | |||
| Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||
| Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, | |||
| std::vector<NodeIdType> *out_neighbors) override; | |||
| // Add neighbor of node | |||
| // @param std::shared_ptr<Node> node - neighbor node to add | |||
| // @param WeightType weight - weight stored alongside the neighbor | |||
| // @return Status The status code returned | |||
| Status AddNeighbor(const std::shared_ptr<Node> &node) override; | |||
| Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override; | |||
| // Update feature of node | |||
| // @param std::shared_ptr<Feature> feature - | |||
| @@ -69,12 +71,16 @@ class LocalNode : public Node { | |||
| Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; | |||
| private: | |||
| Status GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, | |||
| std::vector<NodeIdType> *out); | |||
| Status GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, | |||
| std::vector<NodeIdType> *out); | |||
| Status GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, | |||
| const std::vector<WeightType> &weights, int32_t samples_num, | |||
| std::vector<NodeIdType> *out); | |||
| std::mt19937 rnd_; | |||
| std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_; | |||
| std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_; | |||
| std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| @@ -28,6 +28,7 @@ namespace dataset { | |||
| namespace gnn { | |||
| using NodeType = int8_t; | |||
| using NodeIdType = int32_t; | |||
| using WeightType = float; | |||
| constexpr NodeIdType kDefaultNodeId = -1; | |||
| @@ -36,7 +37,8 @@ class Node { | |||
| // Constructor | |||
| // @param NodeIdType id - node id | |||
| // @param NodeType type - node type | |||
| Node(NodeIdType id, NodeType type) : id_(id), type_(type) {} | |||
| // @param WeightType weight - node weight | |||
| Node(NodeIdType id, NodeType type, WeightType weight) : id_(id), type_(type), weight_(weight) {} | |||
| virtual ~Node() = default; | |||
| @@ -46,6 +48,9 @@ class Node { | |||
| // @return NodeType - Returned node type | |||
| NodeType type() const { return type_; } | |||
| // @return WeightType - Returned node weight | |||
| WeightType weight() const { return weight_; } | |||
| // Get the feature of a node | |||
| // @param FeatureType feature_type - type of feature | |||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | |||
| @@ -62,15 +67,16 @@ class Node { | |||
| // Get the sampled neighbors of a node | |||
| // @param NodeType neighbor_type - type of neighbor | |||
| // @param int32_t samples_num - Number of neighbors to be acquired | |||
| // @param SamplingStrategy strategy - Sampling strategy | |||
| // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id | |||
| // @return Status The status code returned | |||
| virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, | |||
| virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, | |||
| std::vector<NodeIdType> *out_neighbors) = 0; | |||
| // Add neighbor of node | |||
| // @param std::shared_ptr<Node> node - neighbor node to add | |||
| // @param WeightType weight - weight associated with the neighbor | |||
| // @return Status The status code returned | |||
| virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0; | |||
| virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0; | |||
| // Update feature of node | |||
| // @param std::shared_ptr<Feature> feature - | |||
| @@ -80,6 +86,7 @@ class Node { | |||
| protected: | |||
| NodeIdType id_; | |||
| NodeType type_; | |||
| WeightType weight_; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| @@ -71,6 +71,9 @@ enum class NormalizeForm { | |||
| kNfkd, | |||
| }; | |||
// Possible values for SamplingStrategy
enum class SamplingStrategy {
  kRandom = 0,      // neighbors are picked at random
  kEdgeWeight = 1,  // neighbors are picked with the edge weight as probability
};
// Convenience helper for 32-bit bitmasks.
// @param uint32_t bits - value to inspect
// @param uint32_t bitMask - bits that must all be set
// @return bool - true when every bit of bitMask is also set in bits
inline bool BitTest(uint32_t bits, uint32_t bitMask) {
  const uint32_t masked = bits & bitMask;
  return masked == bitMask;
}
| @@ -25,7 +25,7 @@ operations for users to preprocess data: shuffle, batch, repeat, map, and zip. | |||
| from ..core import config | |||
| from .cache_client import DatasetCache | |||
| from .datasets import * | |||
| from .graphdata import GraphData | |||
| from .graphdata import GraphData, SamplingStrategy | |||
| from .iterators import * | |||
| from .samplers import * | |||
| from .serializer_deserializer import compare, deserialize, serialize, show | |||
| @@ -18,10 +18,12 @@ and provides operations related to graph data. | |||
| """ | |||
| import atexit | |||
| import time | |||
| from enum import IntEnum | |||
| import numpy as np | |||
| from mindspore._c_dataengine import GraphDataClient | |||
| from mindspore._c_dataengine import GraphDataServer | |||
| from mindspore._c_dataengine import Tensor | |||
| from mindspore._c_dataengine import SamplingStrategy as Sampling | |||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | |||
| check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ | |||
| @@ -29,6 +31,17 @@ from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_ | |||
| check_gnn_random_walk | |||
class SamplingStrategy(IntEnum):
    """Strategy used when sampling the neighbors of a node.

    RANDOM: neighbors are picked at random.
    EDGE_WEIGHT: neighbors are picked with the edge weight as probability.
    """
    RANDOM = 0
    EDGE_WEIGHT = 1
# Translates the Python-facing SamplingStrategy members into the enum values
# exported by the C++ binding (imported above as `Sampling`).
DE_C_INTER_SAMPLING_STRATEGY = {
    SamplingStrategy.RANDOM: Sampling.DE_SAMPLING_RANDOM,
    SamplingStrategy.EDGE_WEIGHT: Sampling.DE_SAMPLING_EDGE_WEIGHT,
}
| class GraphData: | |||
| """ | |||
| Reads the graph dataset used for GNN training from the shared file and database. | |||
| @@ -86,7 +99,7 @@ class GraphData: | |||
| dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown) | |||
| atexit.register(stop) | |||
| try: | |||
| while self._graph_data.is_stoped() is not True: | |||
| while self._graph_data.is_stopped() is not True: | |||
| time.sleep(1) | |||
| except KeyboardInterrupt: | |||
| raise Exception("Graph data server receives KeyboardInterrupt.") | |||
| @@ -185,7 +198,7 @@ class GraphData: | |||
| return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array() | |||
| @check_gnn_get_sampled_neighbors | |||
| def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): | |||
| def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM): | |||
| """ | |||
| Get sampled neighbor information. | |||
| @@ -199,6 +212,11 @@ class GraphData: | |||
| node_list (Union[list, numpy.ndarray]): The given list of nodes. | |||
| neighbor_nums (Union[list, numpy.ndarray]): Number of neighbors sampled per hop. | |||
| neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop. | |||
| strategy (SamplingStrategy, optional): Sampling strategy (default=SamplingStrategy.RANDOM). | |||
| It can be any of [SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT]. | |||
| - SamplingStrategy.RANDOM, random sampling with replacement. | |||
| - SamplingStrategy.EDGE_WEIGHT, sampling with edge weight as probability. | |||
| Returns: | |||
| numpy.ndarray, array of neighbors. | |||
| @@ -215,10 +233,12 @@ class GraphData: | |||
| TypeError: If `neighbor_nums` is not list or ndarray. | |||
| TypeError: If `neighbor_types` is not list or ndarray. | |||
| """ | |||
| if not isinstance(strategy, SamplingStrategy): | |||
| raise TypeError("Wrong input type for strategy, should be enum of 'SamplingStrategy'.") | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server.") | |||
| return self._graph_data.get_sampled_neighbors( | |||
| node_list, neighbor_nums, neighbor_types).as_array() | |||
| node_list, neighbor_nums, neighbor_types, DE_C_INTER_SAMPLING_STRATEGY[strategy]).as_array() | |||
| @check_gnn_get_neg_sampled_neighbors | |||
| def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): | |||
| @@ -342,7 +362,7 @@ class GraphData: | |||
| target_nodes (list[int]): Start node list in random walk | |||
| meta_path (list[int]): node type for each walk step | |||
| step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0). | |||
| step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0). | |||
| step_away_param (float, optional): in out hyper parameter in node2vec algorithm (Default = 1.0). | |||
| default_node (int, optional): default node if no more neighbors found (Default = -1). | |||
| A default value of -1 indicates that no node is given. | |||
| @@ -1114,7 +1114,7 @@ def check_gnn_get_sampled_neighbors(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs) | |||
| [node_list, neighbor_nums, neighbor_types, _], _ = parse_user_args(method, *args, **kwargs) | |||
| check_gnn_list_or_ndarray(node_list, 'node_list') | |||
| @@ -37,6 +37,7 @@ class GraphMapSchema: | |||
| "second_id": {"type": "int64"}, | |||
| "third_id": {"type": "int64"}, | |||
| "type": {"type": "int32"}, | |||
| "weight": {"type": "float32"}, | |||
| "attribute": {"type": "string"}, # 'n' for node, 'e' for edge | |||
| "node_feature_index": {"type": "int32", "shape": [-1]}, | |||
| "edge_feature_index": {"type": "int32", "shape": [-1]} | |||
| @@ -91,8 +92,11 @@ class GraphMapSchema: | |||
| logger.info("node cannot be None.") | |||
| raise ValueError("node cannot be None.") | |||
| node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"], | |||
| "node_feature_index": []} | |||
| node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "weight": 1.0, "attribute": 'n', | |||
| "type": node["type"], "node_feature_index": []} | |||
| if "weight" in node: | |||
| node_graph["weight"] = node["weight"] | |||
| for i in range(self.num_node_features): | |||
| k = i + 1 | |||
| node_field_key = 'feature_' + str(k) | |||
| @@ -129,8 +133,11 @@ class GraphMapSchema: | |||
| logger.info("edge cannot be None.") | |||
| raise ValueError("edge cannot be None.") | |||
| edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e', | |||
| "type": edge["type"], "edge_feature_index": []} | |||
| edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "weight": 1.0, | |||
| "attribute": 'e', "type": edge["type"], "edge_feature_index": []} | |||
| if "weight" in edge: | |||
| edge_graph["weight"] = edge["weight"] | |||
| for i in range(self.num_edge_features): | |||
| k = i + 1 | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <unordered_set> | |||
| @@ -38,6 +39,60 @@ using namespace mindspore::dataset::gnn; | |||
| class MindDataTestGNNGraph : public UT::Common { | |||
| protected: | |||
| MindDataTestGNNGraph() = default; | |||
| using NumNeighborsMap = std::map<NodeIdType, uint32_t>; | |||
| using NodeNeighborsMap = std::map<NodeIdType, NumNeighborsMap>; | |||
| void ParsingNeighbors(const std::shared_ptr<Tensor> &neighbors, NodeNeighborsMap &node_neighbors) { | |||
| auto shape_vec = neighbors->shape().AsVector(); | |||
| uint32_t num_members = 1; | |||
| for (size_t i = 1; i < shape_vec.size(); ++i) { | |||
| num_members *= shape_vec[i]; | |||
| } | |||
| uint32_t index = 0; | |||
| NodeIdType src_node = 0; | |||
| for (auto node_itr = neighbors->begin<NodeIdType>(); node_itr != neighbors->end<NodeIdType>(); | |||
| ++node_itr, ++index) { | |||
| if (index % num_members == 0) { | |||
| src_node = *node_itr; | |||
| continue; | |||
| } | |||
| auto src_node_itr = node_neighbors.find(src_node); | |||
| if (src_node_itr == node_neighbors.end()) { | |||
| node_neighbors[src_node] = {{*node_itr, 1}}; | |||
| } else { | |||
| auto nei_itr = src_node_itr->second.find(*node_itr); | |||
| if (nei_itr == src_node_itr->second.end()) { | |||
| src_node_itr->second[*node_itr] = 1; | |||
| } else { | |||
| src_node_itr->second[*node_itr] += 1; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector<WeightType> &weights, | |||
| float deviation_ratio = 0.1) { | |||
| EXPECT_EQ(number_neighbors.size(), weights.size()); | |||
| int index = 0; | |||
| uint32_t pre_num = 0; | |||
| WeightType pre_weight = 1; | |||
| for (auto neighbor : number_neighbors) { | |||
| if (pre_num != 0) { | |||
| float target_ratio = static_cast<float>(pre_weight) / static_cast<float>(weights[index]); | |||
| float current_ratio = static_cast<float>(pre_num) / static_cast<float>(neighbor.second); | |||
| float target_upper = target_ratio * (1 + deviation_ratio); | |||
| float target_lower = target_ratio * (1 - deviation_ratio); | |||
| MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio) | |||
| << " target_upper:" << std::to_string(target_upper) | |||
| << " target_lower:" << std::to_string(target_lower); | |||
| EXPECT_LE(current_ratio, target_upper); | |||
| EXPECT_GE(current_ratio, target_lower); | |||
| } | |||
| pre_num = neighbor.second; | |||
| pre_weight = weights[index]; | |||
| ++index; | |||
| } | |||
| } | |||
| }; | |||
| TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||
| @@ -131,44 +186,75 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | |||
| std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; }); | |||
| std::shared_ptr<Tensor> neighbors; | |||
| s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); | |||
| { | |||
| MS_LOG(INFO) << "Test random sampling."; | |||
| NodeNeighborsMap number_neighbors; | |||
| int count = 0; | |||
| while (count < 1000) { | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); | |||
| ParsingNeighbors(neighbors, number_neighbors); | |||
| ++count; | |||
| } | |||
| CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1}); | |||
| } | |||
| { | |||
| MS_LOG(INFO) << "Test edge weight sampling."; | |||
| NodeNeighborsMap number_neighbors; | |||
| int count = 0; | |||
| while (count < 1000) { | |||
| neighbors.reset(); | |||
| s = | |||
| graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); | |||
| ParsingNeighbors(neighbors, number_neighbors); | |||
| ++count; | |||
| } | |||
| CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8}); | |||
| } | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, | |||
| SamplingStrategy::kRandom, &neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>"); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, | |||
| {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); | |||
| {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, | |||
| SamplingStrategy::kRandom, &neighbors); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>"); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors); | |||
| s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors); | |||
| s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, | |||
| SamplingStrategy::kRandom, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors); | |||
| s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); | |||
| s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, | |||
| SamplingStrategy::kRandom, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != | |||
| std::string::npos); | |||
| neighbors.reset(); | |||
| s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors); | |||
| s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); | |||
| EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos); | |||
| } | |||
| @@ -17,6 +17,7 @@ import pytest | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| from mindspore.dataset.engine import SamplingStrategy | |||
| DATASET_FILE = "../data/mindrecord/testGraphData/testdata" | |||
| SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" | |||
| @@ -97,7 +98,10 @@ def test_graphdata_getsampledneighbors(): | |||
| nodes = g.get_nodes_from_edges(edges) | |||
| assert len(nodes) == 40 | |||
| neighbor = g.get_sampled_neighbors( | |||
| np.unique(nodes[0:21, 0]), [2, 3], [2, 1]) | |||
| np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.RANDOM) | |||
| assert neighbor.shape == (10, 9) | |||
| neighbor = g.get_sampled_neighbors( | |||
| np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.EDGE_WEIGHT) | |||
| assert neighbor.shape == (10, 9) | |||
| @@ -20,6 +20,7 @@ from multiprocessing import Process | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| from mindspore.dataset.engine import SamplingStrategy | |||
| DATASET_FILE = "../data/mindrecord/testGraphData/testdata" | |||
| @@ -68,9 +69,9 @@ class GNNGraphDataset(): | |||
| neg_nodes = self.g.get_neg_sampled_neighbors( | |||
| node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1) | |||
| nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[ | |||
| 2, 2], neighbor_types=[2, 1]) | |||
| neg_nodes_neighbors = self.g.get_sampled_neighbors( | |||
| node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2]) | |||
| 2, 2], neighbor_types=[2, 1], strategy=SamplingStrategy.RANDOM) | |||
| neg_nodes_neighbors = self.g.get_sampled_neighbors(node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], | |||
| neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT) | |||
| nodes_neighbors_features = self.g.get_node_feature( | |||
| node_list=nodes_neighbors, feature_types=[2, 3]) | |||
| neg_neighbors_features = self.g.get_node_feature( | |||