| @@ -820,6 +820,12 @@ void bindGraphData(py::module *m) { | |||||
| THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); | THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); | ||||
| return out.getRow(); | return out.getRow(); | ||||
| }) | }) | ||||
| .def("get_edge_feature", | |||||
| [](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) { | |||||
| TensorRow out; | |||||
| THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); | |||||
| return out.getRow(); | |||||
| }) | |||||
| .def("graph_info", | .def("graph_info", | ||||
| [](gnn::Graph &g) { | [](gnn::Graph &g) { | ||||
| py::dict out; | py::dict out; | ||||
| @@ -125,13 +125,8 @@ Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::s | |||||
| Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | ||||
| std::shared_ptr<Tensor> *out) { | std::shared_ptr<Tensor> *out) { | ||||
| if (node_list.empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("Input node_list is empty."); | |||||
| } | |||||
| if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { | |||||
| std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||||
| RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); | |||||
| std::vector<std::vector<NodeIdType>> neighbors; | std::vector<std::vector<NodeIdType>> neighbors; | ||||
| size_t max_neighbor_num = 0; | size_t max_neighbor_num = 0; | ||||
| @@ -161,6 +156,14 @@ Status Graph::CheckSamplesNum(NodeIdType samples_num) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::CheckNeighborType(NodeType neighbor_type) { | |||||
| if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { | |||||
| std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | ||||
| const std::vector<NodeIdType> &neighbor_nums, | const std::vector<NodeIdType> &neighbor_nums, | ||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | ||||
| @@ -171,10 +174,7 @@ Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||||
| RETURN_IF_NOT_OK(CheckSamplesNum(num)); | RETURN_IF_NOT_OK(CheckSamplesNum(num)); | ||||
| } | } | ||||
| for (const auto &type : neighbor_types) { | for (const auto &type : neighbor_types) { | ||||
| if (node_type_map_.find(type) == node_type_map_.end()) { | |||||
| std::string err_msg = "Invalid neighbor type:" + std::to_string(type); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| RETURN_IF_NOT_OK(CheckNeighborType(type)); | |||||
| } | } | ||||
| std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size()); | std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size()); | ||||
| for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | ||||
| @@ -228,44 +228,36 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N | |||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | ||||
| RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); | RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); | ||||
| if (node_type_map_.find(neg_neighbor_type) == node_type_map_.end()) { | |||||
| std::string err_msg = "Invalid neighbor type:" + std::to_string(neg_neighbor_type); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | |||||
| RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); | |||||
| std::vector<std::vector<NodeIdType>> neighbors_vec; | |||||
| neighbors_vec.resize(node_list.size()); | |||||
| std::vector<std::vector<NodeIdType>> neg_neighbors_vec; | |||||
| neg_neighbors_vec.resize(node_list.size()); | |||||
| for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { | ||||
| std::shared_ptr<Node> node; | std::shared_ptr<Node> node; | ||||
| RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); | RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); | ||||
| std::vector<NodeIdType> neighbors; | std::vector<NodeIdType> neighbors; | ||||
| RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); | RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); | ||||
| std::unordered_set<NodeIdType> exclude_node; | |||||
| std::unordered_set<NodeIdType> exclude_nodes; | |||||
| std::transform(neighbors.begin(), neighbors.end(), | std::transform(neighbors.begin(), neighbors.end(), | ||||
| std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_node, exclude_node.begin()), | |||||
| std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_nodes, exclude_nodes.begin()), | |||||
| [](const NodeIdType node) { return node; }); | [](const NodeIdType node) { return node; }); | ||||
| auto itr = node_type_map_.find(neg_neighbor_type); | |||||
| if (itr == node_type_map_.end()) { | |||||
| std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type]; | |||||
| neg_neighbors_vec[node_idx].emplace_back(node->id()); | |||||
| if (all_nodes.size() > exclude_nodes.size()) { | |||||
| while (neg_neighbors_vec[node_idx].size() < samples_num + 1) { | |||||
| RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(), | |||||
| &neg_neighbors_vec[node_idx])); | |||||
| } | |||||
| } else { | } else { | ||||
| neighbors_vec[node_idx].emplace_back(node->id()); | |||||
| if (itr->second.size() > exclude_node.size()) { | |||||
| while (neighbors_vec[node_idx].size() < samples_num + 1) { | |||||
| RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node, samples_num - neighbors_vec[node_idx].size(), | |||||
| &neighbors_vec[node_idx])); | |||||
| } | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() | |||||
| << " neg_neighbor_type:" << neg_neighbor_type; | |||||
| // If there are no negative neighbors, they are filled with kDefaultNodeId | |||||
| for (int32_t i = 0; i < samples_num; ++i) { | |||||
| neighbors_vec[node_idx].emplace_back(kDefaultNodeId); | |||||
| } | |||||
| MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() | |||||
| << " neg_neighbor_type:" << neg_neighbor_type; | |||||
| // If there are no negative neighbors, they are filled with kDefaultNodeId | |||||
| for (int32_t i = 0; i < samples_num; ++i) { | |||||
| neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out)); | |||||
| RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neg_neighbors_vec, DataType(DataType::DE_INT32), out)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -280,8 +272,19 @@ Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::ve | |||||
| } | } | ||||
| Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | ||||
| auto itr = default_feature_map_.find(feature_type); | |||||
| if (itr == default_feature_map_.end()) { | |||||
| auto itr = default_node_feature_map_.find(feature_type); | |||||
| if (itr == default_node_feature_map_.end()) { | |||||
| std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } else { | |||||
| *out_feature = itr->second; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||||
| auto itr = default_edge_feature_map_.find(feature_type); | |||||
| if (itr == default_edge_feature_map_.end()) { | |||||
| std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } else { | } else { | ||||
| @@ -295,7 +298,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve | |||||
| if (!nodes || nodes->Size() == 0) { | if (!nodes || nodes->Size() == 0) { | ||||
| RETURN_STATUS_UNEXPECTED("Input nodes is empty"); | RETURN_STATUS_UNEXPECTED("Input nodes is empty"); | ||||
| } | } | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); | |||||
| TensorRow tensors; | TensorRow tensors; | ||||
| for (const auto &f_type : feature_types) { | for (const auto &f_type : feature_types) { | ||||
| std::shared_ptr<Feature> default_feature; | std::shared_ptr<Feature> default_feature; | ||||
| @@ -340,6 +343,45 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve | |||||
| Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | ||||
| TensorRow *out) { | TensorRow *out) { | ||||
| if (!edges || edges->Size() == 0) { | |||||
| RETURN_STATUS_UNEXPECTED("Input edges is empty"); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); | |||||
| TensorRow tensors; | |||||
| for (const auto &f_type : feature_types) { | |||||
| std::shared_ptr<Feature> default_feature; | |||||
| // If no feature can be obtained, fill in the default value | |||||
| RETURN_IF_NOT_OK(GetEdgeDefaultFeature(f_type, &default_feature)); | |||||
| TensorShape shape(default_feature->Value()->shape()); | |||||
| auto shape_vec = edges->shape().AsVector(); | |||||
| dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>()); | |||||
| shape = shape.PrependDim(size); | |||||
| std::shared_ptr<Tensor> fea_tensor; | |||||
| RETURN_IF_NOT_OK( | |||||
| Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); | |||||
| dsize_t index = 0; | |||||
| for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) { | |||||
| std::shared_ptr<Edge> edge; | |||||
| RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge)); | |||||
| std::shared_ptr<Feature> feature; | |||||
| if (!edge->GetFeatures(f_type, &feature).IsOk()) { | |||||
| feature = default_feature; | |||||
| } | |||||
| RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); | |||||
| index++; | |||||
| } | |||||
| TensorShape reshape(edges->shape()); | |||||
| for (auto s : default_feature->Value()->shape().AsVector()) { | |||||
| reshape = reshape.AppendDim(s); | |||||
| } | |||||
| RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); | |||||
| fea_tensor->Squeeze(); | |||||
| tensors.push_back(fea_tensor); | |||||
| } | |||||
| *out = std::move(tensors); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -405,7 +447,8 @@ Status Graph::LoadNodeAndEdge() { | |||||
| RETURN_IF_NOT_OK(gl.InitAndLoad()); | RETURN_IF_NOT_OK(gl.InitAndLoad()); | ||||
| // get all maps | // get all maps | ||||
| RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, | RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, | ||||
| &node_feature_map_, &edge_feature_map_, &default_feature_map_)); | |||||
| &node_feature_map_, &edge_feature_map_, &default_node_feature_map_, | |||||
| &default_edge_feature_map_)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -420,18 +463,33 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) { | |||||
| auto itr = edge_id_map_.find(id); | |||||
| if (itr == edge_id_map_.end()) { | |||||
| std::string err_msg = "Invalid edge id:" + std::to_string(id); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } else { | |||||
| *edge = itr->second; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Graph::RandomWalkBase::RandomWalkBase(Graph *graph) | Graph::RandomWalkBase::RandomWalkBase(Graph *graph) | ||||
| : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} | : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} | ||||
| Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | ||||
| float step_home_param, float step_away_param, const NodeIdType default_node, | float step_home_param, float step_away_param, const NodeIdType default_node, | ||||
| int32_t num_walks, int32_t num_workers) { | int32_t num_walks, int32_t num_workers) { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||||
| node_list_ = node_list; | node_list_ = node_list; | ||||
| if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { | if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { | ||||
| std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) + | std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) + | ||||
| ". The size of input path is " + std::to_string(meta_path.size()); | ". The size of input path is " + std::to_string(meta_path.size()); | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| for (const auto &type : meta_path) { | |||||
| RETURN_IF_NOT_OK(graph_->CheckNeighborType(type)); | |||||
| } | |||||
| meta_path_ = meta_path; | meta_path_ = meta_path; | ||||
| if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { | if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { | ||||
| std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + | std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + | ||||
| @@ -500,15 +558,10 @@ Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::ve | |||||
| } | } | ||||
| Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) { | Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) { | ||||
| // Repeatedly simulate random walks from each node | |||||
| std::vector<uint32_t> permutation(node_list_.size()); | |||||
| std::iota(permutation.begin(), permutation.end(), 0); | |||||
| for (int32_t i = 0; i < num_walks_; i++) { | for (int32_t i = 0; i < num_walks_; i++) { | ||||
| unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); | |||||
| std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed)); | |||||
| for (const auto &i_perm : permutation) { | |||||
| for (const auto &node : node_list_) { | |||||
| std::vector<NodeIdType> walk; | std::vector<NodeIdType> walk; | ||||
| RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk)); | |||||
| RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); | |||||
| walks->push_back(walk); | walks->push_back(walk); | ||||
| } | } | ||||
| } | } | ||||
| @@ -211,12 +211,24 @@ class Graph { | |||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | ||||
| // Get the default feature of a edge | |||||
| // @param FeatureType feature_type - | |||||
| // @param std::shared_ptr<Feature> *out_feature - Returned feature | |||||
| // @return Status - The error code return | |||||
| Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature); | |||||
| // Find node object using node id | // Find node object using node id | ||||
| // @param NodeIdType id - | // @param NodeIdType id - | ||||
| // @param std::shared_ptr<Node> *node - Returned node object | // @param std::shared_ptr<Node> *node - Returned node object | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node); | Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node); | ||||
| // Find edge object using edge id | |||||
| // @param EdgeIdType id - | |||||
| // @param std::shared_ptr<Node> *edge - Returned edge object | |||||
| // @return Status - The error code return | |||||
| Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge); | |||||
| // Negative sampling | // Negative sampling | ||||
| // @param std::vector<NodeIdType> &input_data - The data set to be sampled | // @param std::vector<NodeIdType> &input_data - The data set to be sampled | ||||
| // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded | // @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded | ||||
| @@ -228,6 +240,8 @@ class Graph { | |||||
| Status CheckSamplesNum(NodeIdType samples_num); | Status CheckSamplesNum(NodeIdType samples_num); | ||||
| Status CheckNeighborType(NodeType neighbor_type); | |||||
| std::string dataset_file_; | std::string dataset_file_; | ||||
| int32_t num_workers_; // The number of worker threads | int32_t num_workers_; // The number of worker threads | ||||
| std::mt19937 rnd_; | std::mt19937 rnd_; | ||||
| @@ -242,7 +256,8 @@ class Graph { | |||||
| std::unordered_map<NodeType, std::unordered_set<FeatureType>> node_feature_map_; | std::unordered_map<NodeType, std::unordered_set<FeatureType>> node_feature_map_; | ||||
| std::unordered_map<EdgeType, std::unordered_set<FeatureType>> edge_feature_map_; | std::unordered_map<EdgeType, std::unordered_set<FeatureType>> edge_feature_map_; | ||||
| std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_feature_map_; | |||||
| std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_node_feature_map_; | |||||
| std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_edge_feature_map_; | |||||
| }; | }; | ||||
| } // namespace gnn | } // namespace gnn | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -41,7 +41,8 @@ GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) | |||||
| Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, | Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, | ||||
| EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, | EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, | ||||
| EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) { | |||||
| EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map, | |||||
| DefaultEdgeFeatureMap *default_edge_feature_map) { | |||||
| for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) { | for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) { | ||||
| while (dq.empty() == false) { | while (dq.empty() == false) { | ||||
| std::shared_ptr<Node> node_ptr = dq.front(); | std::shared_ptr<Node> node_ptr = dq.front(); | ||||
| @@ -70,7 +71,7 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N | |||||
| for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); | for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); | ||||
| for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); | for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); | ||||
| MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map); | |||||
| MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -81,7 +82,8 @@ Status GraphLoader::InitAndLoad() { | |||||
| e_deques_.resize(num_workers_); | e_deques_.resize(num_workers_); | ||||
| n_feature_maps_.resize(num_workers_); | n_feature_maps_.resize(num_workers_); | ||||
| e_feature_maps_.resize(num_workers_); | e_feature_maps_.resize(num_workers_); | ||||
| default_feature_maps_.resize(num_workers_); | |||||
| default_node_feature_maps_.resize(num_workers_); | |||||
| default_edge_feature_maps_.resize(num_workers_); | |||||
| TaskGroup vg; | TaskGroup vg; | ||||
| shard_reader_ = std::make_unique<ShardReader>(); | shard_reader_ = std::make_unique<ShardReader>(); | ||||
| @@ -109,7 +111,7 @@ Status GraphLoader::InitAndLoad() { | |||||
| Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn, | Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn, | ||||
| std::shared_ptr<Node> *node, NodeFeatureMap *feature_map, | std::shared_ptr<Node> *node, NodeFeatureMap *feature_map, | ||||
| DefaultFeatureMap *default_feature) { | |||||
| DefaultNodeFeatureMap *default_feature) { | |||||
| NodeIdType node_id = col_jsn["first_id"]; | NodeIdType node_id = col_jsn["first_id"]; | ||||
| NodeType node_type = static_cast<NodeType>(col_jsn["type"]); | NodeType node_type = static_cast<NodeType>(col_jsn["type"]); | ||||
| (*node) = std::make_shared<LocalNode>(node_id, node_type); | (*node) = std::make_shared<LocalNode>(node_id, node_type); | ||||
| @@ -133,7 +135,7 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec | |||||
| Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn, | Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn, | ||||
| std::shared_ptr<Edge> *edge, EdgeFeatureMap *feature_map, | std::shared_ptr<Edge> *edge, EdgeFeatureMap *feature_map, | ||||
| DefaultFeatureMap *default_feature) { | |||||
| DefaultEdgeFeatureMap *default_feature) { | |||||
| EdgeIdType edge_id = col_jsn["first_id"]; | EdgeIdType edge_id = col_jsn["first_id"]; | ||||
| EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]); | EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]); | ||||
| NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; | NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; | ||||
| @@ -214,13 +216,13 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) { | |||||
| std::string attr = col_jsn["attribute"]; | std::string attr = col_jsn["attribute"]; | ||||
| if (attr == "n") { | if (attr == "n") { | ||||
| std::shared_ptr<Node> node_ptr; | std::shared_ptr<Node> node_ptr; | ||||
| RETURN_IF_NOT_OK( | |||||
| LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id])); | |||||
| RETURN_IF_NOT_OK(LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), | |||||
| &default_node_feature_maps_[worker_id])); | |||||
| n_deques_[worker_id].emplace_back(node_ptr); | n_deques_[worker_id].emplace_back(node_ptr); | ||||
| } else if (attr == "e") { | } else if (attr == "e") { | ||||
| std::shared_ptr<Edge> edge_ptr; | std::shared_ptr<Edge> edge_ptr; | ||||
| RETURN_IF_NOT_OK( | |||||
| LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id])); | |||||
| RETURN_IF_NOT_OK(LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), | |||||
| &default_edge_feature_maps_[worker_id])); | |||||
| e_deques_[worker_id].emplace_back(edge_ptr); | e_deques_[worker_id].emplace_back(edge_ptr); | ||||
| } else { | } else { | ||||
| MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node."; | MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node."; | ||||
| @@ -233,7 +235,8 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) { | |||||
| } | } | ||||
| void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, | void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, | ||||
| DefaultFeatureMap *default_feature_map) { | |||||
| DefaultNodeFeatureMap *default_node_feature_map, | |||||
| DefaultEdgeFeatureMap *default_edge_feature_map) { | |||||
| for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { | for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { | ||||
| for (auto &m : n_feature_maps_[wkr_id]) { | for (auto &m : n_feature_maps_[wkr_id]) { | ||||
| for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); | for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); | ||||
| @@ -241,8 +244,11 @@ void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap | |||||
| for (auto &m : e_feature_maps_[wkr_id]) { | for (auto &m : e_feature_maps_[wkr_id]) { | ||||
| for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); | for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); | ||||
| } | } | ||||
| for (auto &m : default_feature_maps_[wkr_id]) { | |||||
| (*default_feature_map)[m.first] = m.second; | |||||
| for (auto &m : default_node_feature_maps_[wkr_id]) { | |||||
| (*default_node_feature_map)[m.first] = m.second; | |||||
| } | |||||
| for (auto &m : default_edge_feature_maps_[wkr_id]) { | |||||
| (*default_edge_feature_map)[m.first] = m.second; | |||||
| } | } | ||||
| } | } | ||||
| n_feature_maps_.clear(); | n_feature_maps_.clear(); | ||||
| @@ -43,7 +43,8 @@ using NodeTypeMap = std::unordered_map<NodeType, std::vector<NodeIdType>>; | |||||
| using EdgeTypeMap = std::unordered_map<EdgeType, std::vector<EdgeIdType>>; | using EdgeTypeMap = std::unordered_map<EdgeType, std::vector<EdgeIdType>>; | ||||
| using NodeFeatureMap = std::unordered_map<NodeType, std::unordered_set<FeatureType>>; | using NodeFeatureMap = std::unordered_map<NodeType, std::unordered_set<FeatureType>>; | ||||
| using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureType>>; | using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureType>>; | ||||
| using DefaultFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; | |||||
| using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; | |||||
| using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; | |||||
| // this class interfaces with the underlying storage format (mindrecord) | // this class interfaces with the underlying storage format (mindrecord) | ||||
| // it returns raw nodes and edges via GetNodesAndEdges | // it returns raw nodes and edges via GetNodesAndEdges | ||||
| @@ -63,7 +64,7 @@ class GraphLoader { | |||||
| // random order. src_node and dst_node in Edge are node_id only with -1 as type. | // random order. src_node and dst_node in Edge are node_id only with -1 as type. | ||||
| // features attached to each node and edge are expected to be filled correctly | // features attached to each node and edge are expected to be filled correctly | ||||
| Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, | Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, | ||||
| DefaultFeatureMap *); | |||||
| DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); | |||||
| private: | private: | ||||
| // | // | ||||
| @@ -77,19 +78,19 @@ class GraphLoader { | |||||
| // @param mindrecord::json &jsn - contains raw data | // @param mindrecord::json &jsn - contains raw data | ||||
| // @param std::shared_ptr<Node> *node - return value | // @param std::shared_ptr<Node> *node - return value | ||||
| // @param NodeFeatureMap *feature_map - | // @param NodeFeatureMap *feature_map - | ||||
| // @param DefaultFeatureMap *default_feature - | |||||
| // @param DefaultNodeFeatureMap *default_feature - | |||||
| // @return Status - the status code | // @return Status - the status code | ||||
| Status LoadNode(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Node> *node, | Status LoadNode(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Node> *node, | ||||
| NodeFeatureMap *feature_map, DefaultFeatureMap *default_feature); | |||||
| NodeFeatureMap *feature_map, DefaultNodeFeatureMap *default_feature); | |||||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | ||||
| // @param mindrecord::json &jsn - contains raw data | // @param mindrecord::json &jsn - contains raw data | ||||
| // @param std::shared_ptr<Edge> *edge - return value, the edge ptr, edge is not yet connected | // @param std::shared_ptr<Edge> *edge - return value, the edge ptr, edge is not yet connected | ||||
| // @param FeatureMap *feature_map | // @param FeatureMap *feature_map | ||||
| // @param DefaultFeatureMap *default_feature - | |||||
| // @param DefaultEdgeFeatureMap *default_feature - | |||||
| // @return Status - the status code | // @return Status - the status code | ||||
| Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge, | Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge, | ||||
| EdgeFeatureMap *feature_map, DefaultFeatureMap *default_feature); | |||||
| EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); | |||||
| // @param std::string key - column name | // @param std::string key - column name | ||||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | ||||
| @@ -108,7 +109,7 @@ class GraphLoader { | |||||
| std::shared_ptr<Tensor> *tensor); | std::shared_ptr<Tensor> *tensor); | ||||
| // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 | // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 | ||||
| void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultFeatureMap *); | |||||
| void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); | |||||
| const int32_t num_workers_; | const int32_t num_workers_; | ||||
| std::atomic_int row_id_; | std::atomic_int row_id_; | ||||
| @@ -118,7 +119,8 @@ class GraphLoader { | |||||
| std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_; | std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_; | ||||
| std::vector<NodeFeatureMap> n_feature_maps_; | std::vector<NodeFeatureMap> n_feature_maps_; | ||||
| std::vector<EdgeFeatureMap> e_feature_maps_; | std::vector<EdgeFeatureMap> e_feature_maps_; | ||||
| std::vector<DefaultFeatureMap> default_feature_maps_; | |||||
| std::vector<DefaultNodeFeatureMap> default_node_feature_maps_; | |||||
| std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_; | |||||
| const std::vector<std::string> keys_; | const std::vector<std::string> keys_; | ||||
| }; | }; | ||||
| } // namespace gnn | } // namespace gnn | ||||
| @@ -22,7 +22,8 @@ from mindspore._c_dataengine import Tensor | |||||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | ||||
| check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ | check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ | ||||
| check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk | |||||
| check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_get_edge_feature, \ | |||||
| check_gnn_random_walk | |||||
| class GraphData: | class GraphData: | ||||
| @@ -127,7 +128,13 @@ class GraphData: | |||||
| @check_gnn_get_sampled_neighbors | @check_gnn_get_sampled_neighbors | ||||
| def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): | def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): | ||||
| """ | """ | ||||
| Get sampled neighbor information, maximum support 6-hop sampling. | |||||
| Get sampled neighbor information. | |||||
| The api supports multi-hop neighbor sampling. That is, the previous sampling result is used as the input of | |||||
| next-hop sampling. A maximum of 6-hop are allowed. | |||||
| The sampling result is tiled into a list in the format of [input node, 1-hop sampling result, | |||||
| 2-hop samling result ...] | |||||
| Args: | Args: | ||||
| node_list (list or numpy.ndarray): The given list of nodes. | node_list (list or numpy.ndarray): The given list of nodes. | ||||
| @@ -207,6 +214,35 @@ class GraphData: | |||||
| Tensor(node_list), | Tensor(node_list), | ||||
| feature_types)] | feature_types)] | ||||
| @check_gnn_get_edge_feature | |||||
| def get_edge_feature(self, edge_list, feature_types): | |||||
| """ | |||||
| Get `feature_types` feature of the edges in `edge_list`. | |||||
| Args: | |||||
| edge_list (list or numpy.ndarray): The given list of edges. | |||||
| feature_types (list or ndarray): The given list of feature types. | |||||
| Returns: | |||||
| numpy.ndarray: array of features. | |||||
| Examples: | |||||
| >>> import mindspore.dataset as ds | |||||
| >>> data_graph = ds.GraphData('dataset_file', 2) | |||||
| >>> edges = data_graph.get_all_edges(0) | |||||
| >>> features = data_graph.get_edge_feature(edges, [1]) | |||||
| Raises: | |||||
| TypeError: If `edge_list` is not list or ndarray. | |||||
| TypeError: If `feature_types` is not list or ndarray. | |||||
| """ | |||||
| if isinstance(edge_list, list): | |||||
| edge_list = np.array(edge_list, dtype=np.int32) | |||||
| return [ | |||||
| t.as_array() for t in self._graph.get_edge_feature( | |||||
| Tensor(edge_list), | |||||
| feature_types)] | |||||
| def graph_info(self): | def graph_info(self): | ||||
| """ | """ | ||||
| Get the meta information of the graph, including the number of nodes, the type of nodes, | Get the meta information of the graph, including the number of nodes, the type of nodes, | ||||
| @@ -797,7 +797,7 @@ def check_gnn_graphdata(method): | |||||
| check_file(dataset_file) | check_file(dataset_file) | ||||
| if num_parallel_workers is not None: | if num_parallel_workers is not None: | ||||
| type_check(num_parallel_workers, (int,), "num_parallel_workers") | |||||
| check_num_parallel_workers(num_parallel_workers) | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
| @@ -970,6 +970,28 @@ def check_gnn_get_node_feature(method): | |||||
| return new_method | return new_method | ||||
| def check_gnn_get_edge_feature(method): | |||||
| """A wrapper that wrap a parameter checker to the GNN `get_edge_feature` function.""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs) | |||||
| type_check(edge_list, (list, np.ndarray), "edge_list") | |||||
| if isinstance(edge_list, list): | |||||
| check_aligned_list(edge_list, 'edge_list', int) | |||||
| elif isinstance(edge_list, np.ndarray): | |||||
| if not edge_list.dtype == np.int32: | |||||
| raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( | |||||
| edge_list, edge_list.dtype)) | |||||
| check_gnn_list_or_ndarray(feature_types, 'feature_types') | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | |||||
| def check_numpyslicesdataset(method): | def check_numpyslicesdataset(method): | ||||
| """A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset).""" | """A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset).""" | ||||
| @@ -49,9 +49,10 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) { | |||||
| EdgeTypeMap e_type_map; | EdgeTypeMap e_type_map; | ||||
| NodeFeatureMap n_feature_map; | NodeFeatureMap n_feature_map; | ||||
| EdgeFeatureMap e_feature_map; | EdgeFeatureMap e_feature_map; | ||||
| DefaultFeatureMap default_feature_map; | |||||
| DefaultNodeFeatureMap default_node_feature_map; | |||||
| DefaultEdgeFeatureMap default_edge_feature_map; | |||||
| EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map, | EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map, | ||||
| &default_feature_map) | |||||
| &default_node_feature_map, &default_edge_feature_map) | |||||
| .IsOk()); | .IsOk()); | ||||
| EXPECT_EQ(n_id_map.size(), 20); | EXPECT_EQ(n_id_map.size(), 20); | ||||
| EXPECT_EQ(e_id_map.size(), 40); | EXPECT_EQ(e_id_map.size(), 40); | ||||
| @@ -119,6 +120,17 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | |||||
| std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(), | std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(), | ||||
| [](const EdgeIdType edge) { return edge; }); | [](const EdgeIdType edge) { return edge; }); | ||||
| TensorRow edge_features; | |||||
| s = graph.GetEdgeFeature(edges, meta_info.edge_feature_type, &edge_features); | |||||
| EXPECT_TRUE(s.IsOk()); | |||||
| EXPECT_TRUE(edge_features[0]->ToString() == | |||||
| "Tensor (shape: <40>, Type: int32)\n" | |||||
| "[0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0]"); | |||||
| EXPECT_TRUE(edge_features[1]->ToString() == | |||||
| "Tensor (shape: <40>, Type: float32)\n" | |||||
| "[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2,2.1,2.2,2.3,2.4,2.5,2.6,2." | |||||
| "7,2.8,2.9,3,3.1,3.2,3.3,3.4,3.5,3.6,3.7,3.8,3.9,4]"); | |||||
| std::shared_ptr<Tensor> nodes; | std::shared_ptr<Tensor> nodes; | ||||
| s = graph.GetNodesFromEdges(edge_list, &nodes); | s = graph.GetNodesFromEdges(edge_list, &nodes); | ||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| @@ -125,7 +125,7 @@ def test_graphdata_graphinfo(): | |||||
| assert graph_info['node_num'] == {1: 10, 2: 10} | assert graph_info['node_num'] == {1: 10, 2: 10} | ||||
| assert graph_info['edge_num'] == {0: 40} | assert graph_info['edge_num'] == {0: 40} | ||||
| assert graph_info['node_feature_type'] == [1, 2, 3, 4] | assert graph_info['node_feature_type'] == [1, 2, 3, 4] | ||||
| assert graph_info['edge_feature_type'] == [] | |||||
| assert graph_info['edge_feature_type'] == [1, 2] | |||||
| class RandomBatchedSampler(ds.Sampler): | class RandomBatchedSampler(ds.Sampler): | ||||
| @@ -204,7 +204,6 @@ def test_graphdata_randomwalkdefault(): | |||||
| logger.info('test randomwalk with default parameters.\n') | logger.info('test randomwalk with default parameters.\n') | ||||
| g = ds.GraphData(SOCIAL_DATA_FILE, 1) | g = ds.GraphData(SOCIAL_DATA_FILE, 1) | ||||
| nodes = g.get_all_nodes(1) | nodes = g.get_all_nodes(1) | ||||
| print(len(nodes)) | |||||
| assert len(nodes) == 33 | assert len(nodes) == 33 | ||||
| meta_path = [1 for _ in range(39)] | meta_path = [1 for _ in range(39)] | ||||
| @@ -219,7 +218,6 @@ def test_graphdata_randomwalk(): | |||||
| logger.info('test random walk with given parameters.\n') | logger.info('test random walk with given parameters.\n') | ||||
| g = ds.GraphData(SOCIAL_DATA_FILE, 1) | g = ds.GraphData(SOCIAL_DATA_FILE, 1) | ||||
| nodes = g.get_all_nodes(1) | nodes = g.get_all_nodes(1) | ||||
| print(len(nodes)) | |||||
| assert len(nodes) == 33 | assert len(nodes) == 33 | ||||
| meta_path = [1 for _ in range(39)] | meta_path = [1 for _ in range(39)] | ||||
| @@ -227,6 +225,18 @@ def test_graphdata_randomwalk(): | |||||
| assert walks.shape == (33, 40) | assert walks.shape == (33, 40) | ||||
| def test_graphdata_getedgefeature(): | |||||
| """ | |||||
| Test get edge feature | |||||
| """ | |||||
| logger.info('test get_edge_feature.\n') | |||||
| g = ds.GraphData(DATASET_FILE) | |||||
| edges = g.get_all_edges(0) | |||||
| features = g.get_edge_feature(edges, [1, 2]) | |||||
| assert features[0].shape == (40,) | |||||
| assert features[1].shape == (40,) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_graphdata_getfullneighbor() | test_graphdata_getfullneighbor() | ||||
| test_graphdata_getnodefeature_input_check() | test_graphdata_getnodefeature_input_check() | ||||
| @@ -236,3 +246,4 @@ if __name__ == '__main__': | |||||
| test_graphdata_generatordataset() | test_graphdata_generatordataset() | ||||
| test_graphdata_randomwalkdefault() | test_graphdata_randomwalkdefault() | ||||
| test_graphdata_randomwalk() | test_graphdata_randomwalk() | ||||
| test_graphdata_getedgefeature() | |||||