@@ -15,7 +15,14 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake)
 include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake)
 include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
+SET(MS_BUILD_GRPC 0)
 if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES)
+    SET(MS_BUILD_GRPC 1)
+endif()
+if (ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
+    SET(MS_BUILD_GRPC 1)
+endif()
+if ("${MS_BUILD_GRPC}")
     # build dependencies of gRPC
     include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake)
     include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake)
@@ -83,6 +83,7 @@ endif()
 if (ENABLE_TDTQUE)
     add_dependencies(engine-tdt core)
 endif ()
+
 ################### Create _c_dataengine Library ######################
 set(submodules
     $<TARGET_OBJECTS:core>
@@ -182,3 +183,7 @@ else()
         set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON)
     endif ()
 endif()
+
+if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
+    target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
+endif()
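A note on what the two CMake hunks above set up: gRPC and its dependencies are pulled in only when `MS_BUILD_GRPC` is set, and `_c_dataengine` links `mindspore::grpc++` only off Windows. The runtime sources mirror that gate with preprocessor guards; a minimal sketch of the pairing (the function here is hypothetical, not from the patch):

```cpp
// Sketch: how the build-time gate above pairs with compile-time guards in
// the sources. AnyRpcDependentRoutine is a hypothetical stand-in.
#if !defined(_WIN32) && !defined(_WIN64)
#include "grpcpp/grpcpp.h"
#endif

void AnyRpcDependentRoutine() {
#if defined(_WIN32) || defined(_WIN64)
  // gRPC is not built on Windows, so the distributed path is unavailable.
#else
  // gRPC-backed implementation lives here.
#endif
}
```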
@@ -18,83 +18,103 @@
 #include "pybind11/stl_bind.h"
 
 #include "minddata/dataset/api/python/pybind_register.h"
-#include "minddata/dataset/engine/gnn/graph.h"
+#include "minddata/dataset/engine/gnn/graph_data_client.h"
+#include "minddata/dataset/engine/gnn/graph_data_impl.h"
+#include "minddata/dataset/engine/gnn/graph_data_server.h"
 
 namespace mindspore {
 namespace dataset {
 
 PYBIND_REGISTER(
   Graph, 0, ([](const py::module *m) {
-    (void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
-      .def(py::init([](std::string dataset_file, int32_t num_workers) {
-        std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
-        THROW_IF_ERROR(g_out->Init());
-        return g_out;
+    (void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient")
+      .def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &working_mode,
+                       const std::string &hostname, int32_t port) {
+        std::shared_ptr<gnn::GraphData> out;
+        if (working_mode == "local") {
+          out = std::make_shared<gnn::GraphDataImpl>(dataset_file, num_workers);
+        } else if (working_mode == "client") {
+          out = std::make_shared<gnn::GraphDataClient>(dataset_file, hostname, port);
+        }
+        THROW_IF_ERROR(out->Init());
+        return out;
       }))
       .def("get_all_nodes",
-           [](gnn::Graph &g, gnn::NodeType node_type) {
+           [](gnn::GraphData &g, gnn::NodeType node_type) {
             std::shared_ptr<Tensor> out;
             THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
             return out;
           })
       .def("get_all_edges",
-           [](gnn::Graph &g, gnn::EdgeType edge_type) {
+           [](gnn::GraphData &g, gnn::EdgeType edge_type) {
             std::shared_ptr<Tensor> out;
             THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
             return out;
           })
       .def("get_nodes_from_edges",
-           [](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
+           [](gnn::GraphData &g, std::vector<gnn::NodeIdType> edge_list) {
             std::shared_ptr<Tensor> out;
             THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
             return out;
           })
       .def("get_all_neighbors",
-           [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
+           [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
             std::shared_ptr<Tensor> out;
             THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
             return out;
           })
       .def("get_sampled_neighbors",
-           [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
+           [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
              std::vector<gnn::NodeType> neighbor_types) {
             std::shared_ptr<Tensor> out;
             THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
             return out;
           })
       .def("get_neg_sampled_neighbors",
-           [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
+           [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
              gnn::NodeType neg_neighbor_type) {
             std::shared_ptr<Tensor> out;
             THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
             return out;
           })
       .def("get_node_feature",
-           [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
+           [](gnn::GraphData &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
             TensorRow out;
             THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
             return out.getRow();
           })
       .def("get_edge_feature",
-           [](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
+           [](gnn::GraphData &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
             TensorRow out;
             THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
             return out.getRow();
           })
       .def("graph_info",
-           [](gnn::Graph &g) {
+           [](gnn::GraphData &g) {
             py::dict out;
             THROW_IF_ERROR(g.GraphInfo(&out));
             return out;
           })
       .def("random_walk",
-           [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
+           [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
              float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
             std::shared_ptr<Tensor> out;
             THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
             return out;
-          });
+          })
+      .def("stop", [](gnn::GraphData &g) { THROW_IF_ERROR(g.Stop()); });
+
+    (void)py::class_<gnn::GraphDataServer, std::shared_ptr<gnn::GraphDataServer>>(*m, "GraphDataServer")
+      .def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
+                       int32_t client_num, bool auto_shutdown) {
+        std::shared_ptr<gnn::GraphDataServer> out;
+        out =
+          std::make_shared<gnn::GraphDataServer>(dataset_file, num_workers, hostname, port, client_num, auto_shutdown);
+        THROW_IF_ERROR(out->Init());
+        return out;
+      }))
+      .def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); })
+      .def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); });
   }));
 
 }  // namespace dataset
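One behavior of the factory above worth noting: if `working_mode` is neither `"local"` nor `"client"`, `out` remains null and `out->Init()` dereferences a null pointer. A defensive sketch of the same dispatch, with stand-in types and hypothetical error text (not part of the patch):

```cpp
#include <memory>
#include <stdexcept>
#include <string>

// Stand-ins for gnn::GraphData and its two concrete implementations.
struct GraphData { virtual ~GraphData() = default; };
struct GraphDataImpl : GraphData { GraphDataImpl(std::string, int) {} };
struct GraphDataClient : GraphData { GraphDataClient(std::string, std::string, int) {} };

// Same dispatch as the py::init lambda, but unknown modes are rejected
// before anything is dereferenced.
std::shared_ptr<GraphData> MakeGraphData(const std::string &dataset_file, int num_workers,
                                         const std::string &working_mode, const std::string &hostname, int port) {
  if (working_mode == "local") {
    return std::make_shared<GraphDataImpl>(dataset_file, num_workers);
  }
  if (working_mode == "client") {
    return std::make_shared<GraphDataClient>(dataset_file, hostname, port);
  }
  throw std::invalid_argument("Unknown working mode: " + working_mode);
}
```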
@@ -1,9 +1,29 @@
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_library(engine-gnn OBJECT
-    graph.cc
+set(DATASET_ENGINE_GNN_SRC_FILES
+    graph_data_impl.cc
+    graph_data_client.cc
+    graph_data_server.cc
     graph_loader.cc
+    graph_feature_parser.cc
     local_node.cc
     local_edge.cc
     feature.cc
-)
+)
+
+if (CMAKE_SYSTEM_NAME MATCHES "Windows")
+    add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES})
+else()
+    set(DATASET_ENGINE_GNN_SRC_FILES
+        ${DATASET_ENGINE_GNN_SRC_FILES}
+        tensor_proto.cc
+        grpc_async_server.cc
+        graph_data_service_impl.cc
+        graph_shared_memory.cc)
+    ms_protobuf_generate(TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS "gnn_tensor.proto")
+    ms_grpc_generate(GNN_PROTO_SRCS GNN_PROTO_HDRS "gnn_graph_data.proto")
+    add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES} ${TENSOR_PROTO_SRCS} ${GNN_PROTO_SRCS})
+    add_dependencies(engine-gnn mindspore::protobuf)
+endif()
@@ -19,7 +19,8 @@ namespace mindspore {
 namespace dataset {
 namespace gnn {
 
-Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value) : type_name_(type_name), value_(value) {}
+Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory)
+    : type_name_(type_name), value_(value), is_shared_memory_(is_shared_memory) {}
 
 }  // namespace gnn
 }  // namespace dataset
@@ -31,7 +31,7 @@ class Feature {
   // Constructor
   // @param FeatureType type_name - feature type
   // @param std::shared_ptr<Tensor> value - feature value
-  Feature(FeatureType type_name, std::shared_ptr<Tensor> value);
+  Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory = false);
 
   ~Feature() = default;
 
@@ -45,6 +45,7 @@ class Feature {
  private:
   FeatureType type_name_;
   std::shared_ptr<Tensor> value_;
+  bool is_shared_memory_;
 };
 
 }  // namespace gnn
 }  // namespace dataset
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+syntax = "proto3";
+
+package mindspore.dataset;
+
+import "gnn_tensor.proto";
+
+message GnnClientRegisterRequestPb {
+  int32 pid = 1;
+}
+
+message GnnFeatureInfoPb {
+  int32 type = 1;
+  TensorPb feature = 2;
+}
+
+message GnnClientRegisterResponsePb {
+  string error_msg = 1;
+  string data_schema = 2;
+  int64 shared_memory_key = 3;
+  int64 shared_memory_size = 4;
+  repeated GnnFeatureInfoPb default_node_feature = 5;
+  repeated GnnFeatureInfoPb default_edge_feature = 6;
+}
+
+message GnnClientUnRegisterRequestPb {
+  int32 pid = 1;
+}
+
+message GnnClientUnRegisterResponsePb {
+  string error_msg = 1;
+}
+
+enum GnnOpName {
+  GET_ALL_NODES = 0;
+  GET_ALL_EDGES = 1;
+  GET_NODES_FROM_EDGES = 2;
+  GET_ALL_NEIGHBORS = 3;
+  GET_SAMPLED_NEIGHBORS = 4;
+  GET_NEG_SAMPLED_NEIGHBORS = 5;
+  RANDOM_WALK = 6;
+  GET_NODE_FEATURE = 7;
+  GET_EDGE_FEATURE = 8;
+}
+
+message GnnRandomWalkPb {
+  float p = 1;
+  float q = 2;
+  int32 default_id = 3;
+}
+
+message GnnGraphDataRequestPb {
+  GnnOpName op_name = 1;
+  repeated int32 id = 2;      // node id or edge id
+  repeated int32 type = 3;    // node type, edge type, neighbor type or feature type
+  repeated int32 number = 4;  // number of samples
+  TensorPb id_tensor = 5;     // input ids, node id or edge id
+  GnnRandomWalkPb random_walk = 6;
+}
+
+message GnnGraphDataResponsePb {
+  string error_msg = 1;
+  repeated TensorPb result_data = 2;
+}
+
+message GnnMetaInfoRequestPb {
+}
+
+message GnnNodeEdgeInfoPb {
+  int32 type = 1;
+  int32 num = 2;
+}
+
+message GnnMetaInfoResponsePb {
+  string error_msg = 1;
+  repeated GnnNodeEdgeInfoPb node_info = 2;
+  repeated GnnNodeEdgeInfoPb edge_info = 3;
+  repeated int32 node_feature_type = 4;
+  repeated int32 edge_feature_type = 5;
+}
+
+service GnnGraphData {
+  rpc ClientRegister(GnnClientRegisterRequestPb) returns (GnnClientRegisterResponsePb);
+  rpc ClientUnRegister(GnnClientUnRegisterRequestPb) returns (GnnClientUnRegisterResponsePb);
+  rpc GetGraphData(GnnGraphDataRequestPb) returns (GnnGraphDataResponsePb);
+  rpc GetMetaInfo(GnnMetaInfoRequestPb) returns (GnnMetaInfoResponsePb);
+}
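For reference, this is how the client code later in the patch drives the service: it fills a `GnnGraphDataRequestPb` with an op code plus the generic `id`/`type`/`number` fields, and the server multiplexes on `op_name`. A minimal sketch using the protobuf-generated C++ API (message and field names come straight from the .proto above; the sample values are illustrative):

```cpp
#include "proto/gnn_graph_data.pb.h"  // generated from gnn_graph_data.proto

// Build a GET_SAMPLED_NEIGHBORS request: 2-hop sampling from nodes 1 and 2,
// taking 10 then 5 neighbors of type 0 at each hop.
mindspore::dataset::GnnGraphDataRequestPb MakeSampledNeighborsRequest() {
  mindspore::dataset::GnnGraphDataRequestPb request;
  request.set_op_name(mindspore::dataset::GET_SAMPLED_NEIGHBORS);
  request.add_id(1);       // seed node ids
  request.add_id(2);
  request.add_number(10);  // neighbors sampled at hop 1
  request.add_number(5);   // neighbors sampled at hop 2
  request.add_type(0);     // neighbor type for hop 1
  request.add_type(0);     // neighbor type for hop 2
  return request;
}
```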
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+syntax = "proto3";
+
+package mindspore.dataset;
+
+enum DataTypePb {
+  DE_PB_UNKNOWN = 0;
+  DE_PB_BOOL = 1;
+  DE_PB_INT8 = 2;
+  DE_PB_UINT8 = 3;
+  DE_PB_INT16 = 4;
+  DE_PB_UINT16 = 5;
+  DE_PB_INT32 = 6;
+  DE_PB_UINT32 = 7;
+  DE_PB_INT64 = 8;
+  DE_PB_UINT64 = 9;
+  DE_PB_FLOAT16 = 10;
+  DE_PB_FLOAT32 = 11;
+  DE_PB_FLOAT64 = 12;
+  DE_PB_STRING = 13;
+}
+
+message TensorPb {
+  repeated int64 dims = 1;     // tensor shape info
+  DataTypePb tensor_type = 2;  // tensor content data type
+  bytes data = 3;              // tensor data
+}
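`TensorPb` is the wire form both sides use for tensors: a shape vector, a type tag, and the raw buffer as `bytes`. A hedged sketch of filling one by hand with the generated API (the shape and float payload here are illustrative):

```cpp
#include <string>
#include <vector>

#include "proto/gnn_tensor.pb.h"  // generated from gnn_tensor.proto

// Pack a 2x3 float32 buffer into the wire representation.
mindspore::dataset::TensorPb PackFloatTensor(const std::vector<float> &values) {
  mindspore::dataset::TensorPb pb;
  pb.add_dims(2);  // tensor shape: 2 rows ...
  pb.add_dims(3);  // ... of 3 columns
  pb.set_tensor_type(mindspore::dataset::DE_PB_FLOAT32);
  pb.set_data(std::string(reinterpret_cast<const char *>(values.data()), values.size() * sizeof(float)));
  return pb;
}
```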
@@ -0,0 +1,134 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include <utility>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/tensor_row.h"
+#include "minddata/dataset/engine/gnn/feature.h"
+#include "minddata/dataset/engine/gnn/node.h"
+#include "minddata/dataset/engine/gnn/edge.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+namespace gnn {
+
+struct MetaInfo {
+  std::vector<NodeType> node_type;
+  std::vector<EdgeType> edge_type;
+  std::map<NodeType, NodeIdType> node_num;
+  std::map<EdgeType, EdgeIdType> edge_num;
+  std::vector<FeatureType> node_feature_type;
+  std::vector<FeatureType> edge_feature_type;
+};
+
+class GraphData {
+ public:
+  // Get all nodes from the graph.
+  // @param NodeType node_type - type of node
+  // @param std::shared_ptr<Tensor> *out - Returned node ids
+  // @return Status - The error code returned
+  virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;
+
+  // Get all edges from the graph.
+  // @param EdgeType edge_type - type of edge
+  // @param std::shared_ptr<Tensor> *out - Returned edge ids
+  // @return Status - The error code returned
+  virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;
+
+  // Get node ids from the given edges.
+  // @param std::vector<EdgeIdType> edge_list - List of edges
+  // @param std::shared_ptr<Tensor> *out - Returned node ids
+  // @return Status - The error code returned
+  virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;
+
+  // Get all neighbors of the given nodes.
+  // @param std::vector<NodeIdType> node_list - List of nodes
+  // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
+  // @param std::shared_ptr<Tensor> *out - Returned neighbor ids. Because different nodes have different numbers of
+  //     neighbors, the output tensor is sized by the maximum neighbor count and padded with -1 where a node has
+  //     fewer neighbors.
+  // @return Status - The error code returned
+  virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
+                                 std::shared_ptr<Tensor> *out) = 0;
+
+  // Get sampled neighbors.
+  // @param std::vector<NodeIdType> node_list - List of nodes
+  // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
+  // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
+  // @param std::shared_ptr<Tensor> *out - Returned neighbor ids
+  // @return Status - The error code returned
+  virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
+                                     const std::vector<NodeIdType> &neighbor_nums,
+                                     const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;
+
+  // Get negative sampled neighbors.
+  // @param std::vector<NodeIdType> node_list - List of nodes
+  // @param NodeIdType samples_num - Number of neighbors sampled
+  // @param NodeType neg_neighbor_type - The type of negative neighbor
+  // @param std::shared_ptr<Tensor> *out - Returned negative neighbor ids
+  // @return Status - The error code returned
+  virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
+                                        NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;
+
+  // Node2vec random walk.
+  // @param std::vector<NodeIdType> node_list - List of nodes
+  // @param std::vector<NodeType> meta_path - node type of each step
+  // @param float step_home_param - return hyper parameter in node2vec algorithm
+  // @param float step_away_param - inout hyper parameter in node2vec algorithm
+  // @param NodeIdType default_node - default node id
+  // @param std::shared_ptr<Tensor> *out - Returned node ids in walk path
+  // @return Status - The error code returned
+  virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
+                            float step_home_param, float step_away_param, NodeIdType default_node,
+                            std::shared_ptr<Tensor> *out) = 0;
+
+  // Get features of the given nodes.
+  // @param std::shared_ptr<Tensor> nodes - List of nodes
+  // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type
+  //     does not exist
+  // @param TensorRow *out - Returned features
+  // @return Status - The error code returned
+  virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
+                                TensorRow *out) = 0;
+
+  // Get features of the given edges.
+  // @param std::shared_ptr<Tensor> edges - List of edges
+  // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type
+  //     does not exist
+  // @param TensorRow *out - Returned features
+  // @return Status - The error code returned
+  virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
+                                TensorRow *out) = 0;
+
+  // Return meta information to the python layer
+  virtual Status GraphInfo(py::dict *out) = 0;
+
+  virtual Status Init() = 0;
+
+  virtual Status Stop() = 0;
+};
+
+}  // namespace gnn
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
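The point of this pure-virtual interface is that the pybind layer above never needs to know whether it is talking to the in-process `GraphDataImpl` or the gRPC-backed `GraphDataClient`. A small caller sketch written only against the interface (`DumpNodeFeatures` is a hypothetical helper; the `Status`/`Tensor` types and `RETURN_IF_NOT_OK` macro are the ones used throughout this patch):

```cpp
// Sketch: fetch all nodes of one type, then their features, purely through
// the GraphData interface. Works identically for any concrete implementation.
Status DumpNodeFeatures(GraphData *graph, NodeType node_type, const std::vector<FeatureType> &feature_types) {
  std::shared_ptr<Tensor> nodes;
  RETURN_IF_NOT_OK(graph->GetAllNodes(node_type, &nodes));
  TensorRow features;
  RETURN_IF_NOT_OK(graph->GetNodeFeature(nodes, feature_types, &features));
  return Status::OK();
}
```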
@@ -0,0 +1,589 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/gnn/graph_data_client.h"
+
+#include <unistd.h>
+#include <functional>
+#include <map>
+
+#if !defined(_WIN32) && !defined(_WIN64)
+#include "grpcpp/grpcpp.h"
+#endif
+
+#include "minddata/dataset/core/data_type.h"
+#if !defined(_WIN32) && !defined(_WIN64)
+#include "minddata/dataset/engine/gnn/tensor_proto.h"
+#endif
+
+namespace mindspore {
+namespace dataset {
+namespace gnn {
+
+GraphDataClient::GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port)
+    : dataset_file_(dataset_file),
+      host_(hostname),
+      port_(port),
+      pid_(0),
+#if !defined(_WIN32) && !defined(_WIN64)
+      shared_memory_key_(-1),
+      shared_memory_size_(0),
+      graph_feature_parser_(nullptr),
+      graph_shared_memory_(nullptr),
+#endif
+      registered_(false) {
+}
+
+GraphDataClient::~GraphDataClient() { (void)Stop(); }
+
+Status GraphDataClient::Init() {
+#if defined(_WIN32) || defined(_WIN64)
+  RETURN_STATUS_UNEXPECTED("Graph data client is not supported on Windows OS");
+#else
+  if (!registered_) {
+    std::string server_address;
+    server_address = host_ + ":" + std::to_string(port_);
+    MS_LOG(INFO) << "Graph data client starting. address:" << server_address;
+    pid_ = getpid();
+    grpc::ChannelArguments args;
+    args.SetMaxReceiveMessageSize(-1);
+    std::shared_ptr<grpc::Channel> channel =
+      grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args);
+    stub_ = GnnGraphData::NewStub(channel);
+    Status status = RegisterToServer();
+    while (status.ToString().find("Initializing") != std::string::npos) {
+      MS_LOG(INFO) << "Graph data server is initializing, please wait.";
+      std::this_thread::sleep_for(std::chrono::milliseconds(2000));
+      status = RegisterToServer();
+    }
+    RETURN_IF_NOT_OK(status);
+    MS_LOG(INFO) << "Graph data client successfully registered with server " << server_address;
+  }
+  RETURN_IF_NOT_OK(InitFeatureParser());
+  return Status::OK();
+#endif
+}
+
+Status GraphDataClient::Stop() {
+#if !defined(_WIN32) && !defined(_WIN64)
+  if (registered_) {
+    UnRegisterToServer();
+  }
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  GnnGraphDataRequestPb request;
+  GnnGraphDataResponsePb response;
+  request.set_op_name(GET_ALL_NODES);
+  request.add_type(static_cast<google::protobuf::int32>(node_type));
+  RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  GnnGraphDataRequestPb request;
+  GnnGraphDataResponsePb response;
+  request.set_op_name(GET_ALL_EDGES);
+  request.add_type(static_cast<google::protobuf::int32>(edge_type));
+  RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  GnnGraphDataRequestPb request;
+  GnnGraphDataResponsePb response;
+  request.set_op_name(GET_NODES_FROM_EDGES);
+  for (const auto &edge_id : edge_list) {
+    request.add_id(static_cast<google::protobuf::int32>(edge_id));
+  }
+  RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
+                                        std::shared_ptr<Tensor> *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  GnnGraphDataRequestPb request;
+  GnnGraphDataResponsePb response;
+  request.set_op_name(GET_ALL_NEIGHBORS);
+  for (const auto &node_id : node_list) {
+    request.add_id(static_cast<google::protobuf::int32>(node_id));
+  }
+  request.add_type(static_cast<google::protobuf::int32>(neighbor_type));
+  RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
+                                            const std::vector<NodeIdType> &neighbor_nums,
+                                            const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  GnnGraphDataRequestPb request;
+  GnnGraphDataResponsePb response;
+  request.set_op_name(GET_SAMPLED_NEIGHBORS);
+  for (const auto &node_id : node_list) {
+    request.add_id(static_cast<google::protobuf::int32>(node_id));
+  }
+  for (const auto &num : neighbor_nums) {
+    request.add_number(static_cast<google::protobuf::int32>(num));
+  }
+  for (const auto &type : neighbor_types) {
+    request.add_type(static_cast<google::protobuf::int32>(type));
+  }
+  RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
+                                               NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  GnnGraphDataRequestPb request;
+  GnnGraphDataResponsePb response;
+  request.set_op_name(GET_NEG_SAMPLED_NEIGHBORS);
+  for (const auto &node_id : node_list) {
+    request.add_id(static_cast<google::protobuf::int32>(node_id));
+  }
+  request.add_number(static_cast<google::protobuf::int32>(samples_num));
+  request.add_type(static_cast<google::protobuf::int32>(neg_neighbor_type));
+  RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
+                                   float step_home_param, float step_away_param, NodeIdType default_node,
+                                   std::shared_ptr<Tensor> *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  GnnGraphDataRequestPb request;
+  GnnGraphDataResponsePb response;
+  request.set_op_name(RANDOM_WALK);
+  for (const auto &node_id : node_list) {
+    request.add_id(static_cast<google::protobuf::int32>(node_id));
+  }
+  for (const auto &type : meta_path) {
+    request.add_type(static_cast<google::protobuf::int32>(type));
+  }
+  auto walk_param = request.mutable_random_walk();
+  walk_param->set_p(step_home_param);
+  walk_param->set_q(step_away_param);
+  walk_param->set_default_id(static_cast<google::protobuf::int32>(default_node));
+  RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::GetNodeFeature(const std::shared_ptr<Tensor> &nodes,
+                                       const std::vector<FeatureType> &feature_types, TensorRow *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  if (!nodes || nodes->Size() == 0) {
+    RETURN_STATUS_UNEXPECTED("Input nodes is empty");
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty");
+  GnnGraphDataRequestPb request;
+  GnnGraphDataResponsePb response;
+  request.set_op_name(GET_NODE_FEATURE);
+  for (const auto &type : feature_types) {
+    request.add_type(static_cast<google::protobuf::int32>(type));
+  }
+  RETURN_IF_NOT_OK(TensorToPb(nodes, request.mutable_id_tensor()));
+  RETURN_IF_NOT_OK(GetGraphData(request, &response));
+  CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == response.result_data().size(),
+                               "The number of feature types returned by the server is wrong");
+  if (response.result_data().size() > 0) {
+    size_t i = 0;
+    for (const auto &result : response.result_data()) {
+      std::shared_ptr<Tensor> tensor;
+      RETURN_IF_NOT_OK(PbToTensor(&result, &tensor));
+      std::shared_ptr<Tensor> fea_tensor;
+      RETURN_IF_NOT_OK(ParseNodeFeatureFromMemory(nodes, feature_types[i], tensor, &fea_tensor));
+      out->emplace_back(std::move(fea_tensor));
+      ++i;
+    }
+  } else {
+    RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensors is abnormal");
+  }
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::GetEdgeFeature(const std::shared_ptr<Tensor> &edges,
+                                       const std::vector<FeatureType> &feature_types, TensorRow *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  if (!edges || edges->Size() == 0) {
+    RETURN_STATUS_UNEXPECTED("Input edges is empty");
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty");
+  GnnGraphDataRequestPb request;
+  GnnGraphDataResponsePb response;
+  request.set_op_name(GET_EDGE_FEATURE);
+  for (const auto &type : feature_types) {
+    request.add_type(static_cast<google::protobuf::int32>(type));
+  }
+  RETURN_IF_NOT_OK(TensorToPb(edges, request.mutable_id_tensor()));
+  RETURN_IF_NOT_OK(GetGraphData(request, &response));
+  CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == response.result_data().size(),
+                               "The number of feature types returned by the server is wrong");
+  if (response.result_data().size() > 0) {
+    size_t i = 0;
+    for (const auto &result : response.result_data()) {
+      std::shared_ptr<Tensor> tensor;
+      RETURN_IF_NOT_OK(PbToTensor(&result, &tensor));
+      std::shared_ptr<Tensor> fea_tensor;
+      RETURN_IF_NOT_OK(ParseEdgeFeatureFromMemory(edges, feature_types[i], tensor, &fea_tensor));
+      out->emplace_back(std::move(fea_tensor));
+      ++i;
+    }
+  } else {
+    RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensors is abnormal");
+  }
+#endif
+  return Status::OK();
+}
+
+Status GraphDataClient::GraphInfo(py::dict *out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+  RETURN_IF_NOT_OK(CheckPid());
+  void *tag;
+  bool ok;
+  grpc::Status status;
+  grpc::ClientContext ctx;
+  grpc::CompletionQueue cq;
+  GnnMetaInfoRequestPb request;
+  GnnMetaInfoResponsePb response;
+  // One minute timeout
+  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
+  ctx.set_deadline(deadline);
+  std::unique_ptr<grpc::ClientAsyncResponseReader<GnnMetaInfoResponsePb>> rpc(
+    stub_->PrepareAsyncGetMetaInfo(&ctx, request, &cq));
+  rpc->StartCall();
+  rpc->Finish(&response, &status, &response);
+  {
+    py::gil_scoped_release gil_release;
+    auto success = cq.Next(&tag, &ok);
+    CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
+    CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag");
+    CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
+  }
+  if (status.ok()) {
+    if (response.error_msg() != "Success") {
+      RETURN_STATUS_UNEXPECTED(response.error_msg());
+    } else {
+      MetaInfo meta_info;
+      for (const auto &node : response.node_info()) {
+        meta_info.node_type.emplace_back(static_cast<NodeType>(node.type()));
+        meta_info.node_num[static_cast<NodeType>(node.type())] = static_cast<NodeIdType>(node.num());
+      }
+      for (const auto &edge : response.edge_info()) {
+        meta_info.edge_type.emplace_back(static_cast<EdgeType>(edge.type()));
+        meta_info.edge_num[static_cast<EdgeType>(edge.type())] = static_cast<EdgeIdType>(edge.num());
+      }
+      for (const auto &feature_type : response.node_feature_type()) {
+        meta_info.node_feature_type.emplace_back(static_cast<FeatureType>(feature_type));
+      }
+      for (const auto &feature_type : response.edge_feature_type()) {
+        meta_info.edge_feature_type.emplace_back(static_cast<FeatureType>(feature_type));
+      }
+      (*out)["node_type"] = py::cast(meta_info.node_type);
+      (*out)["edge_type"] = py::cast(meta_info.edge_type);
+      (*out)["node_num"] = py::cast(meta_info.node_num);
+      (*out)["edge_num"] = py::cast(meta_info.edge_num);
+      (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type);
+      (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type);
+    }
+  } else {
+    auto error_code = status.error_code();
+    RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
+  }
+#endif
+  return Status::OK();
+}
+
+#if !defined(_WIN32) && !defined(_WIN64)
+Status GraphDataClient::GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response) {
+  RETURN_IF_NOT_OK(CheckPid());
+  void *tag;
+  bool ok;
+  grpc::Status status;
+  grpc::ClientContext ctx;
+  grpc::CompletionQueue cq;
+  // One minute timeout
+  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
+  ctx.set_deadline(deadline);
+  std::unique_ptr<grpc::ClientAsyncResponseReader<GnnGraphDataResponsePb>> rpc(
+    stub_->PrepareAsyncGetGraphData(&ctx, request, &cq));
+  rpc->StartCall();
+  rpc->Finish(response, &status, response);
+  {
+    py::gil_scoped_release gil_release;
+    auto success = cq.Next(&tag, &ok);
+    CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
+    CHECK_FAIL_RETURN_UNEXPECTED(tag == response, "Expect the same tag");
+    CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
+  }
+  if (status.ok()) {
+    if (response->error_msg() != "Success") {
+      RETURN_STATUS_UNEXPECTED(response->error_msg());
+    }
+  } else {
+    auto error_code = status.error_code();
+    RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
+  }
+  return Status::OK();
+}
+
+Status GraphDataClient::GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
+                                           std::shared_ptr<Tensor> *out) {
+  RETURN_IF_NOT_OK(GetGraphData(request, response));
+  if (1 == response->result_data().size()) {
+    const TensorPb &result = response->result_data()[0];
+    std::shared_ptr<Tensor> tensor;
+    RETURN_IF_NOT_OK(PbToTensor(&result, &tensor));
+    *out = std::move(tensor);
+  } else {
+    RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensors is abnormal");
+  }
+  return Status::OK();
+}
+
+Status GraphDataClient::ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type,
+                                                   const std::shared_ptr<Tensor> &memory_tensor,
+                                                   std::shared_ptr<Tensor> *out) {
+  std::shared_ptr<Tensor> default_feature;
+  // If no feature can be obtained, fill in the default value
+  RETURN_IF_NOT_OK(GetNodeDefaultFeature(feature_type, &default_feature));
+  TensorShape shape(default_feature->shape());
+  auto shape_vec = nodes->shape().AsVector();
+  dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>());
+  shape = shape.PrependDim(size);
+  std::shared_ptr<Tensor> fea_tensor;
+  RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor));
+  dsize_t index = 0;
+  auto fea_addr_itr = memory_tensor->begin<int64_t>();
+  for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) {
+    int64_t offset = *fea_addr_itr;
+    fea_addr_itr++;
+    int64_t len = *fea_addr_itr;
+    fea_addr_itr++;
+    if (*node_itr == kDefaultNodeId || offset < 0 || len <= 0) {
+      RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature));
+    } else {
+      uchar *start_addr_of_index = nullptr;
+      TensorShape remaining({-1});
+      RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining));
+      RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len));
+    }
+    index++;
+  }
+  TensorShape reshape(nodes->shape());
+  for (auto s : default_feature->shape().AsVector()) {
+    reshape = reshape.AppendDim(s);
+  }
+  RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape));
+  fea_tensor->Squeeze();
+  *out = std::move(fea_tensor);
+  return Status::OK();
+}
+
+Status GraphDataClient::ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type,
+                                                   const std::shared_ptr<Tensor> &memory_tensor,
+                                                   std::shared_ptr<Tensor> *out) {
+  std::shared_ptr<Tensor> default_feature;
+  // If no feature can be obtained, fill in the default value
+  RETURN_IF_NOT_OK(GetEdgeDefaultFeature(feature_type, &default_feature));
+  TensorShape shape(default_feature->shape());
+  auto shape_vec = edges->shape().AsVector();
+  dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>());
+  shape = shape.PrependDim(size);
+  std::shared_ptr<Tensor> fea_tensor;
+  RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor));
+  dsize_t index = 0;
+  auto fea_addr_itr = memory_tensor->begin<int64_t>();
+  for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) {
+    int64_t offset = *fea_addr_itr;
+    fea_addr_itr++;
+    int64_t len = *fea_addr_itr;
+    fea_addr_itr++;
+    if (offset < 0 || len <= 0) {
+      RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature));
+    } else {
+      uchar *start_addr_of_index = nullptr;
+      TensorShape remaining({-1});
+      RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining));
+      RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len));
+    }
+    index++;
+  }
+  TensorShape reshape(edges->shape());
+  for (auto s : default_feature->shape().AsVector()) {
+    reshape = reshape.AppendDim(s);
+  }
+  RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape));
+  fea_tensor->Squeeze();
+  *out = std::move(fea_tensor);
+  return Status::OK();
+}
+
+Status GraphDataClient::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature) {
+  auto itr = default_node_feature_map_.find(feature_type);
+  if (itr == default_node_feature_map_.end()) {
+    std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  } else {
+    *out_feature = itr->second;
+  }
+  return Status::OK();
+}
+
+Status GraphDataClient::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature) {
+  auto itr = default_edge_feature_map_.find(feature_type);
+  if (itr == default_edge_feature_map_.end()) {
+    std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  } else {
+    *out_feature = itr->second;
+  }
+  return Status::OK();
+}
+
+Status GraphDataClient::RegisterToServer() {
+  RETURN_IF_NOT_OK(CheckPid());
+  void *tag;
+  bool ok;
+  grpc::Status status;
+  grpc::ClientContext ctx;
+  grpc::CompletionQueue cq;
+  GnnClientRegisterRequestPb request;
+  GnnClientRegisterResponsePb response;
+  request.set_pid(static_cast<google::protobuf::int32>(pid_));
+  // One minute timeout
+  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
+  ctx.set_deadline(deadline);
+  std::unique_ptr<grpc::ClientAsyncResponseReader<GnnClientRegisterResponsePb>> rpc(
+    stub_->PrepareAsyncClientRegister(&ctx, request, &cq));
+  rpc->StartCall();
+  rpc->Finish(&response, &status, &response);
+  {
+    py::gil_scoped_release gil_release;
+    auto success = cq.Next(&tag, &ok);
+    CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
+    CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag");
+    CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
+  }
+  if (status.ok()) {
+    if (response.error_msg() == "Success") {
+      registered_ = true;
+      data_schema_ = mindrecord::json::parse(response.data_schema());
+      shared_memory_key_ = static_cast<key_t>(response.shared_memory_key());
+      shared_memory_size_ = response.shared_memory_size();
+      MS_LOG(INFO) << "Register success, recv data_schema:" << response.data_schema();
+      for (auto feature_info : response.default_node_feature()) {
+        std::shared_ptr<Tensor> tensor;
+        RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor));
+        default_node_feature_map_[feature_info.type()] = tensor;
+      }
+      for (auto feature_info : response.default_edge_feature()) {
+        std::shared_ptr<Tensor> tensor;
+        RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor));
+        default_edge_feature_map_[feature_info.type()] = tensor;
+      }
+    } else {
+      RETURN_STATUS_UNEXPECTED(response.error_msg());
+    }
+  } else {
+    auto error_code = status.error_code();
+    RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
+  }
+  return Status::OK();
+}
+
+Status GraphDataClient::UnRegisterToServer() {
+  RETURN_IF_NOT_OK(CheckPid());
+  MS_LOG(INFO) << "Graph data client sending unregister request to server.";
+  void *tag;
+  bool ok;
+  grpc::Status status;
+  grpc::ClientContext ctx;
+  grpc::CompletionQueue cq;
+  GnnClientUnRegisterRequestPb request;
+  GnnClientUnRegisterResponsePb response;
+  request.set_pid(static_cast<google::protobuf::int32>(pid_));
+  // One minute timeout
+  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
+  ctx.set_deadline(deadline);
+  std::unique_ptr<grpc::ClientAsyncResponseReader<GnnClientUnRegisterResponsePb>> rpc(
+    stub_->PrepareAsyncClientUnRegister(&ctx, request, &cq));
+  rpc->StartCall();
+  rpc->Finish(&response, &status, &response);
+  {
+    py::gil_scoped_release gil_release;
+    auto success = cq.Next(&tag, &ok);
+    CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
+    CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag");
+    CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
+  }
+  if (status.ok()) {
+    if (response.error_msg() == "Success") {
+      MS_LOG(INFO) << "Unregister success.";
+      registered_ = false;
+    } else {
+      RETURN_STATUS_UNEXPECTED(response.error_msg());
+    }
+  } else {
+    auto error_code = status.error_code();
+    RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
+  }
+  return Status::OK();
+}
+
+Status GraphDataClient::InitFeatureParser() {
+  // get shared memory
+  graph_shared_memory_ = std::make_unique<GraphSharedMemory>(shared_memory_size_, shared_memory_key_);
+  RETURN_IF_NOT_OK(graph_shared_memory_->GetSharedMemory());
+  // build feature parser
+  graph_feature_parser_ = std::make_unique<GraphFeatureParser>(ShardColumn(data_schema_));
+  return Status::OK();
+}
+#endif
+
+}  // namespace gnn
+}  // namespace dataset
+}  // namespace mindspore
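All four RPCs in this file (`GetMetaInfo`, `GetGraphData`, `ClientRegister`, `ClientUnRegister`) follow one pattern: a synchronous call built from gRPC's async pieces, so the Python GIL can be released while the thread blocks on the completion queue. A distilled sketch of that pattern, not code from the patch (`CallUnary` is a hypothetical helper; the stub method pointer is any generated `PrepareAsync*`):

```cpp
#include <chrono>
#include <memory>

#include "grpcpp/grpcpp.h"
#include "pybind11/pybind11.h"

namespace py = pybind11;

// Pattern used by every RPC above: prepare an async unary call, start it,
// and block on a private CompletionQueue with the GIL released. The
// response object doubles as the completion tag.
template <typename Stub, typename Request, typename Response>
grpc::Status CallUnary(Stub *stub, const Request &request, Response *response,
                       std::unique_ptr<grpc::ClientAsyncResponseReader<Response>> (Stub::*prepare)(
                         grpc::ClientContext *, const Request &, grpc::CompletionQueue *)) {
  grpc::ClientContext ctx;
  grpc::CompletionQueue cq;
  grpc::Status status;
  ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(60));  // one minute timeout
  auto rpc = (stub->*prepare)(&ctx, request, &cq);
  rpc->StartCall();
  rpc->Finish(response, &status, response);  // tag == response
  void *tag = nullptr;
  bool ok = false;
  {
    py::gil_scoped_release gil_release;  // let Python threads run while we block
    cq.Next(&tag, &ok);                  // blocks until the call completes
  }
  return status;
}
```

With the generated stub this would be invoked as, e.g., `CallUnary(stub_.get(), request, &response, &GnnGraphData::Stub::PrepareAsyncGetMetaInfo)`.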
| @@ -0,0 +1,185 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_ | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <map> | |||||
| #include <unordered_map> | |||||
| #include <unordered_set> | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| #include "proto/gnn_graph_data.grpc.pb.h" | |||||
| #include "proto/gnn_graph_data.pb.h" | |||||
| #endif | |||||
| #include "minddata/dataset/engine/gnn/graph_data.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_feature_parser.h" | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||||
| #endif | |||||
| #include "minddata/mindrecord/include/common/shard_utils.h" | |||||
| #include "minddata/mindrecord/include/shard_column.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace gnn { | |||||
| class GraphDataClient : public GraphData { | |||||
| public: | |||||
| // Constructor | |||||
| // @param std::string dataset_file - | |||||
| // @param int32_t num_workers - number of parallel threads | |||||
| GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port); | |||||
| ~GraphDataClient(); | |||||
| Status Init() override; | |||||
| Status Stop() override; | |||||
| // Get all nodes from the graph. | |||||
| // @param NodeType node_type - type of node | |||||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id | |||||
| // @return Status - The error code return | |||||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | |||||
| // Get all edges from the graph. | |||||
| // @param NodeType edge_type - type of edge | |||||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||||
| // @return Status - The error code return | |||||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | |||||
| // Get the node id from the edge. | |||||
| // @param std::vector<EdgeIdType> edge_list - List of edges | |||||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||||
| // @return Status - The error code return | |||||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | |||||
| // All neighbors of the acquisition node. | |||||
| // @param std::vector<NodeType> node_list - List of nodes | |||||
| // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | |||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is | |||||
| // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors | |||||
| // is not enough, fill in tensor as -1. | |||||
| // @return Status - The error code return | |||||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||||
| std::shared_ptr<Tensor> *out) override; | |||||
| // Get sampled neighbors. | |||||
| // @param std::vector<NodeType> node_list - List of nodes | |||||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||||
| // @return Status - The error code return | |||||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | |||||
| // Get negative sampled neighbors. | |||||
| // @param std::vector<NodeType> node_list - List of nodes | |||||
| // @param NodeIdType samples_num - Number of neighbors sampled | |||||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | |||||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | |||||
| // @return Status - The error code return | |||||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | |||||
| // Node2vec random walk. | |||||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||||
| // @param std::vector<NodeType> meta_path - node type of each step | |||||
| // @param float step_home_param - return hyper parameter in node2vec algorithm | |||||
| // @param float step_away_param - inout hyper parameter in node2vec algorithm | |||||
| // @param NodeIdType default_node - default node id | |||||
| // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path | |||||
| // @return Status - The error code return | |||||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||||
| std::shared_ptr<Tensor> *out) override; | |||||
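step_home_param and step_away_param are node2vec's p (return) and q (in-out) hyperparameters. In the standard node2vec rule, the unnormalized weight of stepping to a candidate x depends on the shortest-path distance d between x and the previously visited node; a standalone sketch for reference:

  // Unnormalized node2vec transition weight. For a single walk step,
  // d(prev, x) can only be 0 (x is prev), 1 (x also neighbors prev), or 2.
  float TransitionWeight(int d, float p, float q) {
    if (d == 0) return 1.0f / p;  // returning to the previous node
    if (d == 1) return 1.0f;      // staying within prev's neighborhood
    return 1.0f / q;              // exploring away from prev
  }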
| // Get the features of nodes | |||||
| // @param std::shared_ptr<Tensor> nodes - List of nodes | |||||
| // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type | |||||
| // does not exist. | |||||
| // @param TensorRow *out - Returned features | |||||
| // @return Status - The error code returned | |||||
| Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | |||||
| TensorRow *out) override; | |||||
| // Get the features of edges | |||||
| // @param std::shared_ptr<Tensor> edges - List of edges | |||||
| // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type | |||||
| // does not exist. | |||||
| // @param TensorRow *out - Returned features | |||||
| // @return Status - The error code returned | |||||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||||
| TensorRow *out) override; | |||||
| // Return meta information to python layer | |||||
| Status GraphInfo(py::dict *out) override; | |||||
| private: | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| Status ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type, | |||||
| const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out); | |||||
| Status ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type, | |||||
| const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out); | |||||
| Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature); | |||||
| Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature); | |||||
| Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response); | |||||
| Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response, | |||||
| std::shared_ptr<Tensor> *out); | |||||
| Status RegisterToServer(); | |||||
| Status UnRegisterToServer(); | |||||
| Status InitFeatureParser(); | |||||
| Status CheckPid() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(pid_ == getpid(), | |||||
| "Multi-process mode is not supported, please change to use multi-thread"); | |||||
| return Status::OK(); | |||||
| } | |||||
| #endif | |||||
| std::string dataset_file_; | |||||
| std::string host_; | |||||
| int32_t port_; | |||||
| int32_t pid_; | |||||
| mindrecord::json data_schema_; | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| std::unique_ptr<GnnGraphData::Stub> stub_; | |||||
| key_t shared_memory_key_; | |||||
| int64_t shared_memory_size_; | |||||
| std::unique_ptr<GraphFeatureParser> graph_feature_parser_; | |||||
| std::unique_ptr<GraphSharedMemory> graph_shared_memory_; | |||||
| std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_node_feature_map_; | |||||
| std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_edge_feature_map_; | |||||
| #endif | |||||
| bool registered_; | |||||
| }; | |||||
| } // namespace gnn | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_ | |||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "minddata/dataset/engine/gnn/graph.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <functional> | #include <functional> | ||||
| @@ -22,19 +22,25 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/core/tensor_shape.h" | #include "minddata/dataset/core/tensor_shape.h" | ||||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | |||||
| #include "minddata/dataset/util/random.h" | #include "minddata/dataset/util/random.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace gnn { | namespace gnn { | ||||
| Graph::Graph(std::string dataset_file, int32_t num_workers) | |||||
| : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { | |||||
| GraphDataImpl::GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode) | |||||
| : dataset_file_(dataset_file), | |||||
| num_workers_(num_workers), | |||||
| rnd_(GetRandomDevice()), | |||||
| random_walk_(this), | |||||
| server_mode_(server_mode) { | |||||
| rnd_.seed(GetSeed()); | rnd_.seed(GetSeed()); | ||||
| MS_LOG(INFO) << "num_workers:" << num_workers; | MS_LOG(INFO) << "num_workers:" << num_workers; | ||||
| } | } | ||||
| Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { | |||||
| GraphDataImpl::~GraphDataImpl() {} | |||||
| Status GraphDataImpl::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { | |||||
| auto itr = node_type_map_.find(node_type); | auto itr = node_type_map_.find(node_type); | ||||
| if (itr == node_type_map_.end()) { | if (itr == node_type_map_.end()) { | ||||
| std::string err_msg = "Invalid node type:" + std::to_string(node_type); | std::string err_msg = "Invalid node type:" + std::to_string(node_type); | ||||
| @@ -46,8 +52,8 @@ Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| Status GraphDataImpl::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| if (!type.IsCompatible<T>()) { | if (!type.IsCompatible<T>()) { | ||||
| RETURN_STATUS_UNEXPECTED("Data type not compatible"); | RETURN_STATUS_UNEXPECTED("Data type not compatible"); | ||||
| } | } | ||||
| @@ -72,7 +78,7 @@ Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, Data | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) { | |||||
| Status GraphDataImpl::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) { | |||||
| if (!data || data->empty()) { | if (!data || data->empty()) { | ||||
| RETURN_STATUS_UNEXPECTED("Input data is empty"); | RETURN_STATUS_UNEXPECTED("Input data is empty"); | ||||
| } | } | ||||
| @@ -89,7 +95,7 @@ Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_siz | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { | |||||
| Status GraphDataImpl::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { | |||||
| auto itr = edge_type_map_.find(edge_type); | auto itr = edge_type_map_.find(edge_type); | ||||
| if (itr == edge_type_map_.end()) { | if (itr == edge_type_map_.end()) { | ||||
| std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); | std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); | ||||
| @@ -100,7 +106,7 @@ Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) { | |||||
| Status GraphDataImpl::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) { | |||||
| if (edge_list.empty()) { | if (edge_list.empty()) { | ||||
| RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); | RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); | ||||
| } | } | ||||
| @@ -122,8 +128,8 @@ Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::s | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | ||||
| RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); | RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); | ||||
| @@ -143,7 +149,7 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::CheckSamplesNum(NodeIdType samples_num) { | |||||
| Status GraphDataImpl::CheckSamplesNum(NodeIdType samples_num) { | |||||
| NodeIdType all_nodes_number = | NodeIdType all_nodes_number = | ||||
| std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, | std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, | ||||
| [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); | [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); | ||||
| @@ -155,7 +161,7 @@ Status Graph::CheckSamplesNum(NodeIdType samples_num) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::CheckNeighborType(NodeType neighbor_type) { | |||||
| Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) { | |||||
| if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { | if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { | ||||
| std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); | std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| @@ -163,9 +169,9 @@ Status Graph::CheckNeighborType(NodeType neighbor_type) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||||
| const std::vector<NodeIdType> &neighbor_nums, | |||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||||
| Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||||
| const std::vector<NodeIdType> &neighbor_nums, | |||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), | CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), | ||||
| "The sizes of neighbor_nums and neighbor_types are inconsistent."); | "The sizes of neighbor_nums and neighbor_types are inconsistent."); | ||||
| @@ -205,8 +211,9 @@ Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::unordered_set<NodeIdType> &exclude_data, | |||||
| int32_t samples_num, std::vector<NodeIdType> *out_samples) { | |||||
| Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data, | |||||
| const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num, | |||||
| std::vector<NodeIdType> *out_samples) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); | CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); | ||||
| std::vector<NodeIdType> shuffled_id(data.size()); | std::vector<NodeIdType> shuffled_id(data.size()); | ||||
| std::iota(shuffled_id.begin(), shuffled_id.end(), 0); | std::iota(shuffled_id.begin(), shuffled_id.end(), 0); | ||||
| @@ -223,8 +230,8 @@ Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::uno | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||||
| Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | ||||
| RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); | RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); | ||||
| RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); | RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); | ||||
| @@ -260,9 +267,9 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| Status GraphDataImpl::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); | RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); | ||||
| std::vector<std::vector<NodeIdType>> walks; | std::vector<std::vector<NodeIdType>> walks; | ||||
| RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); | RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); | ||||
| @@ -270,7 +277,7 @@ Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::ve | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||||
| Status GraphDataImpl::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||||
| auto itr = default_node_feature_map_.find(feature_type); | auto itr = default_node_feature_map_.find(feature_type); | ||||
| if (itr == default_node_feature_map_.end()) { | if (itr == default_node_feature_map_.end()) { | ||||
| std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | ||||
| @@ -281,7 +288,7 @@ Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Fe | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||||
| Status GraphDataImpl::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||||
| auto itr = default_edge_feature_map_.find(feature_type); | auto itr = default_edge_feature_map_.find(feature_type); | ||||
| if (itr == default_edge_feature_map_.end()) { | if (itr == default_edge_feature_map_.end()) { | ||||
| std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | ||||
| @@ -292,8 +299,8 @@ Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Fe | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | |||||
| TensorRow *out) { | |||||
| Status GraphDataImpl::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, | |||||
| const std::vector<FeatureType> &feature_types, TensorRow *out) { | |||||
| if (!nodes || nodes->Size() == 0) { | if (!nodes || nodes->Size() == 0) { | ||||
| RETURN_STATUS_UNEXPECTED("Input nodes is empty"); | RETURN_STATUS_UNEXPECTED("Input nodes is empty"); | ||||
| } | } | ||||
| @@ -339,8 +346,49 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||||
| TensorRow *out) { | |||||
| Status GraphDataImpl::GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| if (!nodes || nodes->Size() == 0) { | |||||
| RETURN_STATUS_UNEXPECTED("Input nodes is empty"); | |||||
| } | |||||
| TensorShape shape = nodes->shape().AppendDim(2); | |||||
| std::shared_ptr<Tensor> fea_tensor; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor)); | |||||
| auto out_fea_itr = fea_tensor->begin<int64_t>(); | |||||
| for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) { | |||||
| if (*node_itr == kDefaultNodeId) { | |||||
| *out_fea_itr = -1; | |||||
| ++out_fea_itr; | |||||
| *out_fea_itr = -1; | |||||
| ++out_fea_itr; | |||||
| } else { | |||||
| std::shared_ptr<Node> node; | |||||
| RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); | |||||
| std::shared_ptr<Feature> feature; | |||||
| if (!node->GetFeatures(type, &feature).IsOk()) { | |||||
| *out_fea_itr = -1; | |||||
| ++out_fea_itr; | |||||
| *out_fea_itr = -1; | |||||
| ++out_fea_itr; | |||||
| } else { | |||||
| for (auto fea_itr = feature->Value()->begin<int64_t>(); fea_itr != feature->Value()->end<int64_t>(); | |||||
| ++fea_itr) { | |||||
| *out_fea_itr = *fea_itr; | |||||
| ++out_fea_itr; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| fea_tensor->Squeeze(); | |||||
| *out = std::move(fea_tensor); | |||||
| return Status::OK(); | |||||
| } | |||||
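In server mode each node contributes exactly two int64 values here (the output shape gains a trailing dimension of 2), plausibly a locator such as an offset/length pair into the block managed by GraphSharedMemory, with (-1, -1) marking a default node or a missing feature; the actual layout is defined by GraphSharedMemory, which is outside this hunk. A standalone sketch of resolving such pairs against a flat buffer, with hypothetical names:

  #include <cstdint>
  #include <vector>

  // Hypothetical locator pair: where a feature's bytes live in a flat buffer.
  struct FeatureLocator {
    int64_t offset;
    int64_t length;
  };

  // Resolve a locator against the shared buffer; (-1, -1) means "no feature".
  const uint8_t *ResolveFeature(const std::vector<uint8_t> &shared, const FeatureLocator &loc) {
    if (loc.offset < 0 || loc.length < 0) return nullptr;
    return shared.data() + loc.offset;
  }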
| Status GraphDataImpl::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, | |||||
| const std::vector<FeatureType> &feature_types, TensorRow *out) { | |||||
| if (!edges || edges->Size() == 0) { | if (!edges || edges->Size() == 0) { | ||||
| RETURN_STATUS_UNEXPECTED("Input edges is empty"); | RETURN_STATUS_UNEXPECTED("Input edges is empty"); | ||||
| } | } | ||||
| @@ -382,12 +430,45 @@ Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::ve | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::Init() { | |||||
| Status GraphDataImpl::GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type, | |||||
| std::shared_ptr<Tensor> *out) { | |||||
| if (!edges || edges->Size() == 0) { | |||||
| RETURN_STATUS_UNEXPECTED("Input edges is empty"); | |||||
| } | |||||
| TensorShape shape = edges->shape().AppendDim(2); | |||||
| std::shared_ptr<Tensor> fea_tensor; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor)); | |||||
| auto out_fea_itr = fea_tensor->begin<int64_t>(); | |||||
| for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) { | |||||
| std::shared_ptr<Edge> edge; | |||||
| RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge)); | |||||
| std::shared_ptr<Feature> feature; | |||||
| if (!edge->GetFeatures(type, &feature).IsOk()) { | |||||
| *out_fea_itr = -1; | |||||
| ++out_fea_itr; | |||||
| *out_fea_itr = -1; | |||||
| ++out_fea_itr; | |||||
| } else { | |||||
| for (auto fea_itr = feature->Value()->begin<int64_t>(); fea_itr != feature->Value()->end<int64_t>(); ++fea_itr) { | |||||
| *out_fea_itr = *fea_itr; | |||||
| ++out_fea_itr; | |||||
| } | |||||
| } | |||||
| } | |||||
| fea_tensor->Squeeze(); | |||||
| *out = std::move(fea_tensor); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataImpl::Init() { | |||||
| RETURN_IF_NOT_OK(LoadNodeAndEdge()); | RETURN_IF_NOT_OK(LoadNodeAndEdge()); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetMetaInfo(MetaInfo *meta_info) { | |||||
| Status GraphDataImpl::GetMetaInfo(MetaInfo *meta_info) { | |||||
| meta_info->node_type.resize(node_type_map_.size()); | meta_info->node_type.resize(node_type_map_.size()); | ||||
| std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), | std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), | ||||
| [](auto itr) { return itr.first; }); | [](auto itr) { return itr.first; }); | ||||
| @@ -427,7 +508,7 @@ Status Graph::GetMetaInfo(MetaInfo *meta_info) { | |||||
| } | } | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| Status Graph::GraphInfo(py::dict *out) { | |||||
| Status GraphDataImpl::GraphInfo(py::dict *out) { | |||||
| MetaInfo meta_info; | MetaInfo meta_info; | ||||
| RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); | RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); | ||||
| (*out)["node_type"] = py::cast(meta_info.node_type); | (*out)["node_type"] = py::cast(meta_info.node_type); | ||||
| @@ -440,18 +521,16 @@ Status Graph::GraphInfo(py::dict *out) { | |||||
| } | } | ||||
| #endif | #endif | ||||
| Status Graph::LoadNodeAndEdge() { | |||||
| GraphLoader gl(dataset_file_, num_workers_); | |||||
| Status GraphDataImpl::LoadNodeAndEdge() { | |||||
| GraphLoader gl(this, dataset_file_, num_workers_, server_mode_); | |||||
| // ask graph_loader to load everything into memory | // ask graph_loader to load everything into memory | ||||
| RETURN_IF_NOT_OK(gl.InitAndLoad()); | RETURN_IF_NOT_OK(gl.InitAndLoad()); | ||||
| // get all maps | // get all maps | ||||
| RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, | |||||
| &node_feature_map_, &edge_feature_map_, &default_node_feature_map_, | |||||
| &default_edge_feature_map_)); | |||||
| RETURN_IF_NOT_OK(gl.GetNodesAndEdges()); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||||
| Status GraphDataImpl::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||||
| auto itr = node_id_map_.find(id); | auto itr = node_id_map_.find(id); | ||||
| if (itr == node_id_map_.end()) { | if (itr == node_id_map_.end()) { | ||||
| std::string err_msg = "Invalid node id:" + std::to_string(id); | std::string err_msg = "Invalid node id:" + std::to_string(id); | ||||
| @@ -462,7 +541,7 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) { | |||||
| Status GraphDataImpl::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) { | |||||
| auto itr = edge_id_map_.find(id); | auto itr = edge_id_map_.find(id); | ||||
| if (itr == edge_id_map_.end()) { | if (itr == edge_id_map_.end()) { | ||||
| std::string err_msg = "Invalid edge id:" + std::to_string(id); | std::string err_msg = "Invalid edge id:" + std::to_string(id); | ||||
| @@ -473,12 +552,13 @@ Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Graph::RandomWalkBase::RandomWalkBase(Graph *graph) | |||||
| GraphDataImpl::RandomWalkBase::RandomWalkBase(GraphDataImpl *graph) | |||||
| : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} | : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} | ||||
| Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||||
| float step_home_param, float step_away_param, const NodeIdType default_node, | |||||
| int32_t num_walks, int32_t num_workers) { | |||||
| Status GraphDataImpl::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, | |||||
| const std::vector<NodeType> &meta_path, float step_home_param, | |||||
| float step_away_param, const NodeIdType default_node, int32_t num_walks, | |||||
| int32_t num_workers) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | ||||
| node_list_ = node_list; | node_list_ = node_list; | ||||
| if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { | if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { | ||||
| @@ -516,7 +596,7 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) { | |||||
| Status GraphDataImpl::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) { | |||||
| // Simulate a random walk starting from start node. | // Simulate a random walk starting from start node. | ||||
| auto walk = std::vector<NodeIdType>(1, start_node);  // walk is a vector | auto walk = std::vector<NodeIdType>(1, start_node);  // walk is a vector | ||||
| // walk simulate | // walk simulate | ||||
| @@ -556,8 +636,8 @@ Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::ve | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) { | |||||
| for (int32_t i = 0; i < num_walks_; i++) { | |||||
| Status GraphDataImpl::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) { | |||||
| for (int32_t i = 0; i < num_walks_; ++i) { | |||||
| for (const auto &node : node_list_) { | for (const auto &node : node_list_) { | ||||
| std::vector<NodeIdType> walk; | std::vector<NodeIdType> walk; | ||||
| RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); | RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); | ||||
| @@ -567,8 +647,8 @@ Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, | |||||
| std::shared_ptr<StochasticIndex> *node_probability) { | |||||
| Status GraphDataImpl::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, | |||||
| std::shared_ptr<StochasticIndex> *node_probability) { | |||||
| // Generate alias nodes | // Generate alias nodes | ||||
| std::shared_ptr<Node> node; | std::shared_ptr<Node> node; | ||||
| graph_->GetNodeByNodeId(node_id, &node); | graph_->GetNodeByNodeId(node_id, &node); | ||||
| @@ -581,8 +661,9 @@ Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, cons | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, | |||||
| std::shared_ptr<StochasticIndex> *edge_probability) { | |||||
| Status GraphDataImpl::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, | |||||
| uint32_t meta_path_index, | |||||
| std::shared_ptr<StochasticIndex> *edge_probability) { | |||||
| // Get the alias edge setup lists for a given edge. | // Get the alias edge setup lists for a given edge. | ||||
| std::shared_ptr<Node> src_node; | std::shared_ptr<Node> src_node; | ||||
| graph_->GetNodeByNodeId(src, &src_node); | graph_->GetNodeByNodeId(src, &src_node); | ||||
| @@ -616,7 +697,7 @@ Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const No | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) { | |||||
| StochasticIndex GraphDataImpl::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) { | |||||
| uint32_t K = probability.size(); | uint32_t K = probability.size(); | ||||
| std::vector<int32_t> switch_to_large_index(K, 0); | std::vector<int32_t> switch_to_large_index(K, 0); | ||||
| std::vector<float> weight(K, .0); | std::vector<float> weight(K, .0); | ||||
| @@ -644,7 +725,7 @@ StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<flo | |||||
| return StochasticIndex(switch_to_large_index, weight); | return StochasticIndex(switch_to_large_index, weight); | ||||
| } | } | ||||
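GenerateProbability builds a Walker alias table: the StochasticIndex pair holds the switch-to-large indices and the per-bucket weights, and WalkToNextNode then samples a neighbor in O(1) per step. A standalone sketch of the alias method itself (illustrative names, not this file's code):

  #include <cstdlib>
  #include <utility>
  #include <vector>

  using AliasTable = std::pair<std::vector<int>, std::vector<float>>;

  // Walker's alias method: preprocess a distribution so that each draw
  // needs only one uniform bucket index and one uniform coin flip.
  AliasTable BuildAlias(const std::vector<float> &prob) {
    const int k = static_cast<int>(prob.size());
    std::vector<int> alias(k, 0);
    std::vector<float> thresh(k, 0.0f);
    std::vector<int> small, large;
    for (int i = 0; i < k; ++i) {
      thresh[i] = prob[i] * k;
      (thresh[i] < 1.0f ? small : large).push_back(i);
    }
    while (!small.empty() && !large.empty()) {
      int s = small.back(); small.pop_back();
      int l = large.back(); large.pop_back();
      alias[s] = l;                        // overflow from bucket l tops up bucket s
      thresh[l] -= (1.0f - thresh[s]);
      (thresh[l] < 1.0f ? small : large).push_back(l);
    }
    return {alias, thresh};
  }

  int DrawAlias(const AliasTable &t) {
    int i = std::rand() % static_cast<int>(t.first.size());
    float coin = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);
    return coin < t.second[i] ? i : t.first[i];
  }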
| uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { | |||||
| uint32_t GraphDataImpl::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { | |||||
| auto switch_to_large_index = stochastic_index.first; | auto switch_to_large_index = stochastic_index.first; | ||||
| auto weight = stochastic_index.second; | auto weight = stochastic_index.second; | ||||
| const uint32_t size_of_index = switch_to_large_index.size(); | const uint32_t size_of_index = switch_to_large_index.size(); | ||||
| @@ -662,7 +743,7 @@ uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| std::vector<float> Graph::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) { | |||||
| std::vector<float> GraphDataImpl::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) { | |||||
| float sum_probability = | float sum_probability = | ||||
| 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); | 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); | ||||
| if (sum_probability < kGnnEpsilon) { | if (sum_probability < kGnnEpsilon) { | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -25,13 +25,11 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/core/tensor_row.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | |||||
| #include "minddata/dataset/engine/gnn/feature.h" | |||||
| #include "minddata/dataset/engine/gnn/node.h" | |||||
| #include "minddata/dataset/engine/gnn/edge.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_data.h" | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||||
| #endif | |||||
| #include "minddata/mindrecord/include/common/shard_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -41,41 +39,32 @@ const float kGnnEpsilon = 0.0001; | |||||
| const uint32_t kMaxNumWalks = 80; | const uint32_t kMaxNumWalks = 80; | ||||
| using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>; | using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>; | ||||
| struct MetaInfo { | |||||
| std::vector<NodeType> node_type; | |||||
| std::vector<EdgeType> edge_type; | |||||
| std::map<NodeType, NodeIdType> node_num; | |||||
| std::map<EdgeType, EdgeIdType> edge_num; | |||||
| std::vector<FeatureType> node_feature_type; | |||||
| std::vector<FeatureType> edge_feature_type; | |||||
| }; | |||||
| class Graph { | |||||
| class GraphDataImpl : public GraphData { | |||||
| public: | public: | ||||
| // Constructor | // Constructor | ||||
| // @param std::string dataset_file - Path of the dataset file | // @param std::string dataset_file - Path of the dataset file | ||||
| // @param int32_t num_workers - number of parallel threads | // @param int32_t num_workers - number of parallel threads | ||||
| Graph(std::string dataset_file, int32_t num_workers); | |||||
| GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode = false); | |||||
| ~Graph() = default; | |||||
| ~GraphDataImpl(); | |||||
| // Get all nodes from the graph. | // Get all nodes from the graph. | ||||
| // @param NodeType node_type - type of node | // @param NodeType node_type - type of node | ||||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | // @param std::shared_ptr<Tensor> *out - Returned node ids | ||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out); | |||||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | |||||
| // Get all edges from the graph. | // Get all edges from the graph. | ||||
| // @param EdgeType edge_type - type of edge | // @param EdgeType edge_type - type of edge | ||||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | // @param std::shared_ptr<Tensor> *out - Returned edge ids | ||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out); | |||||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | |||||
| // Get the node ids from the edges. | // Get the node ids from the edges. | ||||
| // @param std::vector<EdgeIdType> edge_list - List of edges | // @param std::vector<EdgeIdType> edge_list - List of edges | ||||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | // @param std::shared_ptr<Tensor> *out - Returned node ids | ||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out); | |||||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | |||||
| // Get all neighbors of the given nodes. | // Get all neighbors of the given nodes. | ||||
| // @param std::vector<NodeIdType> node_list - List of nodes | // @param std::vector<NodeIdType> node_list - List of nodes | ||||
| @@ -85,7 +74,7 @@ class Graph { | |||||
| // is not enough, fill in tensor as -1. | // is not enough, fill in tensor as -1. | ||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | ||||
| std::shared_ptr<Tensor> *out); | |||||
| std::shared_ptr<Tensor> *out) override; | |||||
| // Get sampled neighbors. | // Get sampled neighbors. | ||||
| // @param std::vector<NodeType> node_list - List of nodes | // @param std::vector<NodeType> node_list - List of nodes | ||||
| @@ -94,7 +83,7 @@ class Graph { | |||||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor ids | // @param std::shared_ptr<Tensor> *out - Returned neighbor ids | ||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | ||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out); | |||||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | |||||
| // Get negative sampled neighbors. | // Get negative sampled neighbors. | ||||
| // @param std::vector<NodeType> node_list - List of nodes | // @param std::vector<NodeType> node_list - List of nodes | ||||
| @@ -103,7 +92,7 @@ class Graph { | |||||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor ids | // @param std::shared_ptr<Tensor> *out - Returned negative neighbor ids | ||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | ||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out); | |||||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | |||||
| // Node2vec random walk. | // Node2vec random walk. | ||||
| // @param std::vector<NodeIdType> node_list - List of nodes | // @param std::vector<NodeIdType> node_list - List of nodes | ||||
| @@ -115,7 +104,7 @@ class Graph { | |||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | ||||
| float step_home_param, float step_away_param, NodeIdType default_node, | float step_home_param, float step_away_param, NodeIdType default_node, | ||||
| std::shared_ptr<Tensor> *out); | |||||
| std::shared_ptr<Tensor> *out) override; | |||||
| // Get the features of nodes | // Get the features of nodes | ||||
| // @param std::shared_ptr<Tensor> nodes - List of nodes | // @param std::shared_ptr<Tensor> nodes - List of nodes | ||||
| @@ -124,16 +113,22 @@ class Graph { | |||||
| // @param TensorRow *out - Returned features | // @param TensorRow *out - Returned features | ||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||
| Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | ||||
| TensorRow *out); | |||||
| TensorRow *out) override; | |||||
| Status GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type, | |||||
| std::shared_ptr<Tensor> *out); | |||||
| // Get the features of edges | // Get the features of edges | ||||
| // @param std::shared_ptr<Tensor> edget - List of edges | |||||
| // @param std::shared_ptr<Tensor> edges - List of edges | |||||
| // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type | // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type | ||||
| // does not exist. | // does not exist. | ||||
| // @param TensorRow *out - Returned features | // @param TensorRow *out - Returned features | ||||
| // @return Status - The error code returned | // @return Status - The error code returned | ||||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edget, const std::vector<FeatureType> &feature_types, | |||||
| TensorRow *out); | |||||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||||
| TensorRow *out) override; | |||||
| Status GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type, | |||||
| std::shared_ptr<Tensor> *out); | |||||
| // Get meta information of graph | // Get meta information of graph | ||||
| // @param MetaInfo *meta_info - Returned meta information | // @param MetaInfo *meta_info - Returned meta information | ||||
| @@ -142,15 +137,34 @@ class Graph { | |||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| // Return meta information to python layer | // Return meta information to python layer | ||||
| Status GraphInfo(py::dict *out); | |||||
| Status GraphInfo(py::dict *out) override; | |||||
| #endif | #endif | ||||
| Status Init(); | |||||
| const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultNodeFeatures() { | |||||
| return &default_node_feature_map_; | |||||
| } | |||||
| const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultEdgeFeatures() { | |||||
| return &default_edge_feature_map_; | |||||
| } | |||||
| Status Init() override; | |||||
| Status Stop() override { return Status::OK(); } | |||||
| std::string GetDataSchema() { return data_schema_.dump(); } | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| key_t GetSharedMemoryKey() { return graph_shared_memory_->memory_key(); } | |||||
| int64_t GetSharedMemorySize() { return graph_shared_memory_->memory_size(); } | |||||
| #endif | |||||
| private: | private: | ||||
| friend class GraphLoader; | |||||
| class RandomWalkBase { | class RandomWalkBase { | ||||
| public: | public: | ||||
| explicit RandomWalkBase(Graph *graph); | |||||
| explicit RandomWalkBase(GraphDataImpl *graph); | |||||
| Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | ||||
| float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, | float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, | ||||
| @@ -176,7 +190,7 @@ class Graph { | |||||
| template <typename T> | template <typename T> | ||||
| std::vector<float> Normalize(const std::vector<T> &non_normalized_probability); | std::vector<float> Normalize(const std::vector<T> &non_normalized_probability); | ||||
| Graph *graph_; | |||||
| GraphDataImpl *graph_; | |||||
| std::vector<NodeIdType> node_list_; | std::vector<NodeIdType> node_list_; | ||||
| std::vector<NodeType> meta_path_; | std::vector<NodeType> meta_path_; | ||||
| float step_home_param_; // Return hyper parameter. Default is 1.0 | float step_home_param_; // Return hyper parameter. Default is 1.0 | ||||
| @@ -248,7 +262,11 @@ class Graph { | |||||
| int32_t num_workers_; // The number of worker threads | int32_t num_workers_; // The number of worker threads | ||||
| std::mt19937 rnd_; | std::mt19937 rnd_; | ||||
| RandomWalkBase random_walk_; | RandomWalkBase random_walk_; | ||||
| mindrecord::json data_schema_; | |||||
| bool server_mode_; | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| std::unique_ptr<GraphSharedMemory> graph_shared_memory_; | |||||
| #endif | |||||
| std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_; | std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_; | ||||
| std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_; | std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_; | ||||
| @@ -264,4 +282,4 @@ class Graph { | |||||
| } // namespace gnn | } // namespace gnn | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_ | |||||
| @@ -0,0 +1,133 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/gnn/graph_data_server.h" | |||||
| #include <algorithm> | |||||
| #include <functional> | |||||
| #include <iterator> | |||||
| #include <numeric> | |||||
| #include <utility> | |||||
| #include "minddata/dataset/core/tensor_shape.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||||
| #include "minddata/dataset/util/random.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace gnn { | |||||
| GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, | |||||
| int32_t port, int32_t client_num, bool auto_shutdown) | |||||
| : dataset_file_(dataset_file), | |||||
| num_workers_(num_workers), | |||||
| client_num_(client_num), | |||||
| max_connected_client_num_(0), | |||||
| auto_shutdown_(auto_shutdown), | |||||
| state_(kGdsUninit) { | |||||
| tg_ = std::make_unique<TaskGroup>(); | |||||
| graph_data_impl_ = std::make_unique<GraphDataImpl>(dataset_file, num_workers, true); | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| service_impl_ = std::make_unique<GraphDataServiceImpl>(this, graph_data_impl_.get()); | |||||
| async_server_ = std::make_unique<GraphDataGrpcServer>(hostname, port, service_impl_.get()); | |||||
| #endif | |||||
| } | |||||
| Status GraphDataServer::Init() { | |||||
| #if defined(_WIN32) || defined(_WIN64) | |||||
| RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS"); | |||||
| #else | |||||
| set_state(kGdsInitializing); | |||||
| RETURN_IF_NOT_OK(async_server_->Run()); | |||||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask("init graph data impl", std::bind(&GraphDataServer::InitGraphDataImpl, this))); | |||||
| for (int32_t i = 0; i < num_workers_; ++i) { | |||||
| RETURN_IF_NOT_OK( | |||||
| tg_->CreateAsyncTask("start async rpc service", std::bind(&GraphDataServer::StartAsyncRpcService, this))); | |||||
| } | |||||
| if (auto_shutdown_) { | |||||
| RETURN_IF_NOT_OK( | |||||
| tg_->CreateAsyncTask("judge auto shutdown server", std::bind(&GraphDataServer::JudgeAutoShutdownServer, this))); | |||||
| } | |||||
| return Status::OK(); | |||||
| #endif | |||||
| } | |||||
| Status GraphDataServer::InitGraphDataImpl() { | |||||
| TaskManager::FindMe()->Post(); | |||||
| Status s = graph_data_impl_->Init(); | |||||
| if (s.IsOk()) { | |||||
| set_state(kGdsRunning); | |||||
| } else { | |||||
| (void)Stop(); | |||||
| } | |||||
| return s; | |||||
| } | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| Status GraphDataServer::StartAsyncRpcService() { | |||||
| TaskManager::FindMe()->Post(); | |||||
| RETURN_IF_NOT_OK(async_server_->HandleRequest()); | |||||
| return Status::OK(); | |||||
| } | |||||
| #endif | |||||
| Status GraphDataServer::JudgeAutoShutdownServer() { | |||||
| TaskManager::FindMe()->Post(); | |||||
| while (true) { | |||||
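| // Exit once every expected client has connected at least once | |||||
| // (max_connected_client_num_ >= client_num_) and all of them have since | |||||
| // unregistered (client_pid_ is empty), or when the server was stopped externally. | |||||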
| if (auto_shutdown_ && (max_connected_client_num_ >= client_num_) && (client_pid_.size() == 0)) { | |||||
| MS_LOG(INFO) << "All clients have been unregister, automatically exit the server."; | |||||
| RETURN_IF_NOT_OK(Stop()); | |||||
| break; | |||||
| } | |||||
| if (state_ == kGdsStopped) { | |||||
| break; | |||||
| } | |||||
| std::this_thread::sleep_for(std::chrono::milliseconds(1000)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServer::Stop() { | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| async_server_->Stop(); | |||||
| #endif | |||||
| set_state(kGdsStopped); | |||||
| graph_data_impl_.reset(); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServer::ClientRegister(int32_t pid) { | |||||
| std::unique_lock<std::mutex> lck(mutex_); | |||||
| MS_LOG(INFO) << "client register pid:" << std::to_string(pid); | |||||
| client_pid_.emplace(pid); | |||||
| if (client_pid_.size() > max_connected_client_num_) { | |||||
| max_connected_client_num_ = client_pid_.size(); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServer::ClientUnRegister(int32_t pid) { | |||||
| std::unique_lock<std::mutex> lck(mutex_); | |||||
| auto itr = client_pid_.find(pid); | |||||
| if (itr != client_pid_.end()) { | |||||
| client_pid_.erase(itr); | |||||
| MS_LOG(INFO) << "client unregister pid:" << std::to_string(pid); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace gnn | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
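For orientation, a minimal, hypothetical launcher for this server could look as follows; the dataset path, address, port, and client count are placeholders, and the real entry point (driven from the Python layer) is not part of this patch:

  #include <chrono>
  #include <thread>
  #include "minddata/dataset/engine/gnn/graph_data_server.h"

  int main() {
    // One expected client; auto-shutdown once it unregisters (arguments are illustrative).
    mindspore::dataset::gnn::GraphDataServer server(
        "/path/to/graph.mindrecord", /*num_workers=*/4, "127.0.0.1", /*port=*/50051,
        /*client_num=*/1, /*auto_shutdown=*/true);
    if (!server.Init().IsOk()) return 1;  // Init fails outright on Windows
    while (!server.IsStoped()) {
      std::this_thread::sleep_for(std::chrono::seconds(1));
    }
    return 0;
  }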
| @@ -0,0 +1,196 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_ | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <string> | |||||
| #include <unordered_set> | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| #include "grpcpp/grpcpp.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_data_service_impl.h" | |||||
| #include "minddata/dataset/engine/gnn/grpc_async_server.h" | |||||
| #endif | |||||
| #include "minddata/dataset/util/task_manager.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace gnn { | |||||
| class GraphDataImpl; | |||||
| class GraphDataServer { | |||||
| public: | |||||
| enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; | |||||
| GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port, | |||||
| int32_t client_num, bool auto_shutdown); | |||||
| ~GraphDataServer() = default; | |||||
| Status Init(); | |||||
| Status Stop(); | |||||
| Status ClientRegister(int32_t pid); | |||||
| Status ClientUnRegister(int32_t pid); | |||||
| enum ServerState state() { return state_; } | |||||
| bool IsStoped() { return state_ == kGdsStopped; } | |||||
| private: | |||||
| void set_state(enum ServerState state) { state_ = state; } | |||||
| Status InitGraphDataImpl(); | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| Status StartAsyncRpcService(); | |||||
| #endif | |||||
| Status JudgeAutoShutdownServer(); | |||||
| std::string dataset_file_; | |||||
| int32_t num_workers_; // The number of worker threads | |||||
| int32_t client_num_; | |||||
| int32_t max_connected_client_num_; | |||||
| bool auto_shutdown_; | |||||
| enum ServerState state_; | |||||
| std::unique_ptr<TaskGroup> tg_; // Class for worker management | |||||
| std::unique_ptr<GraphDataImpl> graph_data_impl_; | |||||
| std::unordered_set<int32_t> client_pid_; | |||||
| std::mutex mutex_; | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| std::unique_ptr<GraphDataServiceImpl> service_impl_; | |||||
| std::unique_ptr<GrpcAsyncServer> async_server_; | |||||
| #endif | |||||
| }; | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| class UntypedCall { | |||||
| public: | |||||
| virtual ~UntypedCall() {} | |||||
| virtual Status operator()() = 0; | |||||
| }; | |||||
| template <class ServiceImpl, class AsyncService, class RequestMessage, class ResponseMessage> | |||||
| class CallData : public UntypedCall { | |||||
| public: | |||||
| enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; | |||||
| using EnqueueFunction = void (AsyncService::*)(grpc::ServerContext *, RequestMessage *, | |||||
| grpc::ServerAsyncResponseWriter<ResponseMessage> *, | |||||
| grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *); | |||||
| using HandleRequestFunction = grpc::Status (ServiceImpl::*)(grpc::ServerContext *, const RequestMessage *, | |||||
| ResponseMessage *); | |||||
| CallData(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq, | |||||
| EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) | |||||
| : status_(STATE::CREATE), | |||||
| service_impl_(service_impl), | |||||
| async_service_(async_service), | |||||
| cq_(cq), | |||||
| enqueue_function_(enqueue_function), | |||||
| handle_request_function_(handle_request_function), | |||||
| responder_(&ctx_) {} | |||||
| ~CallData() = default; | |||||
| static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq, | |||||
| EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) { | |||||
| auto call = new CallData<ServiceImpl, AsyncService, RequestMessage, ResponseMessage>( | |||||
| service_impl, async_service, cq, enqueue_function, handle_request_function); | |||||
| RETURN_IF_NOT_OK((*call)()); | |||||
| return Status::OK(); | |||||
| } | |||||
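| // Tag-based state machine driven by the gRPC completion queue: | |||||
| // CREATE - arm async_service_ so the next incoming RPC is matched to this object; | |||||
| // PROCESS - first completion: re-arm a fresh CallData for the next RPC, then run | |||||
| // the handler and start sending the reply; | |||||
| // FINISH - the reply has been delivered, so the object deletes itself. | |||||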
| Status operator()() { | |||||
| if (status_ == STATE::CREATE) { | |||||
| status_ = STATE::PROCESS; | |||||
| (async_service_->*enqueue_function_)(&ctx_, &request_, &responder_, cq_, cq_, this); | |||||
| } else if (status_ == STATE::PROCESS) { | |||||
| EnqueueRequest(service_impl_, async_service_, cq_, enqueue_function_, handle_request_function_); | |||||
| status_ = STATE::FINISH; | |||||
| grpc::Status s = (service_impl_->*handle_request_function_)(&ctx_, &request_, &response_); | |||||
| responder_.Finish(response_, s, this); | |||||
| } else { | |||||
| GPR_ASSERT(status_ == STATE::FINISH); | |||||
| delete this; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| private: | |||||
| STATE status_; | |||||
| ServiceImpl *service_impl_; | |||||
| AsyncService *async_service_; | |||||
| grpc::ServerCompletionQueue *cq_; | |||||
| EnqueueFunction enqueue_function_; | |||||
| HandleRequestFunction handle_request_function_; | |||||
| grpc::ServerContext ctx_; | |||||
| grpc::ServerAsyncResponseWriter<ResponseMessage> responder_; | |||||
| RequestMessage request_; | |||||
| ResponseMessage response_; | |||||
| }; | |||||
| #define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg) \ | |||||
| do { \ | |||||
| Status s = \ | |||||
| CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, request_msg, response_msg>::EnqueueRequest( \ | |||||
| service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method, \ | |||||
| &gnn::GraphDataServiceImpl::method); \ | |||||
| RETURN_IF_NOT_OK(s); \ | |||||
| } while (0) | |||||
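| // For example, the GetMetaInfo line below expands (via token pasting) to a call with | |||||
| // &GnnGraphData::AsyncService::RequestGetMetaInfo as the enqueue function and | |||||
| // &gnn::GraphDataServiceImpl::GetMetaInfo as the handler. | |||||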
| class GraphDataGrpcServer : public GrpcAsyncServer { | |||||
| public: | |||||
| GraphDataGrpcServer(const std::string &host, int32_t port, GraphDataServiceImpl *service_impl) | |||||
| : GrpcAsyncServer(host, port), service_impl_(service_impl) {} | |||||
| Status RegisterService(grpc::ServerBuilder *builder) { | |||||
| builder->RegisterService(&svc_); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status EnqueueRequest() { | |||||
| ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientRegister, GnnClientRegisterRequestPb, | |||||
| GnnClientRegisterResponsePb); | |||||
| ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientUnRegister, GnnClientUnRegisterRequestPb, | |||||
| GnnClientUnRegisterResponsePb); | |||||
| ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetGraphData, GnnGraphDataRequestPb, GnnGraphDataResponsePb); | |||||
| ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ProcessRequest(void *tag) { | |||||
| auto rq = static_cast<UntypedCall *>(tag); | |||||
| RETURN_IF_NOT_OK((*rq)()); | |||||
| return Status::OK(); | |||||
| } | |||||
| private: | |||||
| GraphDataServiceImpl *service_impl_; | |||||
| GnnGraphData::AsyncService svc_; | |||||
| }; | |||||
| #endif | |||||
| } // namespace gnn | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_ | |||||
| @@ -0,0 +1,299 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/gnn/graph_data_service_impl.h" | |||||
| #include <algorithm> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/engine/gnn/tensor_proto.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_data_server.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace gnn { | |||||
| using pFunction = Status (GraphDataServiceImpl::*)(const GnnGraphDataRequestPb *, GnnGraphDataResponsePb *); | |||||
| static std::unordered_map<uint32_t, pFunction> g_get_graph_data_func_ = { | |||||
| {GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes}, | |||||
| {GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges}, | |||||
| {GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges}, | |||||
| {GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors}, | |||||
| {GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors}, | |||||
| {GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors}, | |||||
| {RANDOM_WALK, &GraphDataServiceImpl::RandomWalk}, | |||||
| {GET_NODE_FEATURE, &GraphDataServiceImpl::GetNodeFeature}, | |||||
| {GET_EDGE_FEATURE, &GraphDataServiceImpl::GetEdgeFeature}}; | |||||
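| // GetGraphData() below dispatches through this table: the op code carried in | |||||
| // GnnGraphDataRequestPb::op_name() selects the member function to invoke, so adding | |||||
| // a new graph operation only requires a new entry here plus its handler. | |||||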
| GraphDataServiceImpl::GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl) | |||||
| : server_(server), graph_data_impl_(graph_data_impl) {} | |||||
| Status GraphDataServiceImpl::FillDefaultFeature(GnnClientRegisterResponsePb *response) { | |||||
| const auto default_node_features = graph_data_impl_->GetAllDefaultNodeFeatures(); | |||||
| for (const auto &feature : *default_node_features) { | |||||
| GnnFeatureInfoPb *feature_info = response->add_default_node_feature(); | |||||
| feature_info->set_type(feature.first); | |||||
| RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature())); | |||||
| } | |||||
| const auto default_edge_features = graph_data_impl_->GetAllDefaultEdgeFeatures(); | |||||
| for (const auto &feature : *default_edge_features) { | |||||
| GnnFeatureInfoPb *feature_info = response->add_default_edge_feature(); | |||||
| feature_info->set_type(feature.first); | |||||
| RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature())); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context, | |||||
| const GnnClientRegisterRequestPb *request, | |||||
| GnnClientRegisterResponsePb *response) { | |||||
| Status s = server_->ClientRegister(request->pid()); | |||||
| if (s.IsOk()) { | |||||
| switch (server_->state()) { | |||||
| case GraphDataServer::kGdsUninit: | |||||
| case GraphDataServer::kGdsInitializing: | |||||
| response->set_error_msg("Initializing"); | |||||
| break; | |||||
| case GraphDataServer::kGdsRunning: | |||||
| response->set_error_msg("Success"); | |||||
| response->set_data_schema(graph_data_impl_->GetDataSchema()); | |||||
| response->set_shared_memory_key(graph_data_impl_->GetSharedMemoryKey()); | |||||
| response->set_shared_memory_size(graph_data_impl_->GetSharedMemorySize()); | |||||
| s = FillDefaultFeature(response); | |||||
| if (!s.IsOk()) { | |||||
| response->set_error_msg(s.ToString()); | |||||
| } | |||||
| break; | |||||
| case GraphDataServer::kGdsStopped: | |||||
| response->set_error_msg("Stoped"); | |||||
| break; | |||||
| } | |||||
| } else { | |||||
| response->set_error_msg(s.ToString()); | |||||
| } | |||||
| return ::grpc::Status::OK; | |||||
| } | |||||
| grpc::Status GraphDataServiceImpl::ClientUnRegister(grpc::ServerContext *context, | |||||
| const GnnClientUnRegisterRequestPb *request, | |||||
| GnnClientUnRegisterResponsePb *response) { | |||||
| Status s = server_->ClientUnRegister(request->pid()); | |||||
| if (s.IsOk()) { | |||||
| response->set_error_msg("Success"); | |||||
| } else { | |||||
| response->set_error_msg(s.ToString()); | |||||
| } | |||||
| return ::grpc::Status::OK; | |||||
| } | |||||
| grpc::Status GraphDataServiceImpl::GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request, | |||||
| GnnGraphDataResponsePb *response) { | |||||
| // MS_LOG(INFO) << "#### receive GetGraphData:" << request->op_name(); | |||||
| Status s; | |||||
| auto iter = g_get_graph_data_func_.find(request->op_name()); | |||||
| if (iter != g_get_graph_data_func_.end()) { | |||||
| pFunction func = iter->second; | |||||
| s = (this->*func)(request, response); | |||||
| if (s.IsOk()) { | |||||
| response->set_error_msg("Success"); | |||||
| } else { | |||||
| response->set_error_msg(s.ToString()); | |||||
| } | |||||
| } else { | |||||
| response->set_error_msg("Invalid op name."); | |||||
| } | |||||
| // MS_LOG(INFO) << "#### end receive GetGraphData:" << request->op_name(); | |||||
| return ::grpc::Status::OK; | |||||
| } | |||||
| grpc::Status GraphDataServiceImpl::GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request, | |||||
| GnnMetaInfoResponsePb *response) { | |||||
| MetaInfo meta_info; | |||||
| Status s = graph_data_impl_->GetMetaInfo(&meta_info); | |||||
| if (s.IsOk()) { | |||||
| response->set_error_msg("Success"); | |||||
| for (const auto &type : meta_info.node_type) { | |||||
| auto node_info = response->add_node_info(); | |||||
| node_info->set_type(static_cast<google::protobuf::int32>(type)); | |||||
| auto itr = meta_info.node_num.find(type); | |||||
| if (itr != meta_info.node_num.end()) { | |||||
| node_info->set_num(static_cast<google::protobuf::int32>(itr->second)); | |||||
| } else { | |||||
| node_info->set_num(0); | |||||
| } | |||||
| } | |||||
| for (const auto &type : meta_info.edge_type) { | |||||
| auto edge_info = response->add_edge_info(); | |||||
| edge_info->set_type(static_cast<google::protobuf::int32>(type)); | |||||
| auto itr = meta_info.edge_num.find(type); | |||||
| if (itr != meta_info.edge_num.end()) { | |||||
| edge_info->set_num(static_cast<google::protobuf::int32>(itr->second)); | |||||
| } else { | |||||
| edge_info->set_num(0); | |||||
| } | |||||
| } | |||||
| for (const auto &type : meta_info.node_feature_type) { | |||||
| response->add_node_feature_type(static_cast<google::protobuf::int32>(type)); | |||||
| } | |||||
| for (const auto &type : meta_info.edge_feature_type) { | |||||
| response->add_edge_feature_type(static_cast<google::protobuf::int32>(type)); | |||||
| } | |||||
| } else { | |||||
| response->set_error_msg(s.ToString()); | |||||
| } | |||||
| return ::grpc::Status::OK; | |||||
| } | |||||
| Status GraphDataServiceImpl::GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->GetAllNodes(static_cast<NodeType>(request->type()[0]), &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServiceImpl::GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->GetAllEdges(static_cast<EdgeType>(request->type()[0]), &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input edge id is empty"); | |||||
| std::vector<EdgeIdType> edge_list; | |||||
| edge_list.resize(request->id().size()); | |||||
| std::transform(request->id().begin(), request->id().end(), edge_list.begin(), | |||||
| [](const google::protobuf::int32 id) { return static_cast<EdgeIdType>(id); }); | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->GetNodesFromEdges(edge_list, &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); | |||||
| std::vector<NodeIdType> node_list; | |||||
| node_list.resize(request->id().size()); | |||||
| std::transform(request->id().begin(), request->id().end(), node_list.begin(), | |||||
| [](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); }); | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->GetAllNeighbors(node_list, static_cast<NodeType>(request->type()[0]), &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *request, | |||||
| GnnGraphDataResponsePb *response) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() > 0, "The input neighbor number is empty"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input neighbor type is empty"); | |||||
| std::vector<NodeIdType> node_list; | |||||
| node_list.resize(request->id().size()); | |||||
| std::transform(request->id().begin(), request->id().end(), node_list.begin(), | |||||
| [](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); }); | |||||
| std::vector<NodeIdType> neighbor_nums; | |||||
| neighbor_nums.resize(request->number().size()); | |||||
| std::transform(request->number().begin(), request->number().end(), neighbor_nums.begin(), | |||||
| [](const google::protobuf::int32 num) { return static_cast<NodeIdType>(num); }); | |||||
| std::vector<NodeType> neighbor_types; | |||||
| neighbor_types.resize(request->type().size()); | |||||
| std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(), | |||||
| [](const google::protobuf::int32 type) { return static_cast<NodeType>(type); }); | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServiceImpl::GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, | |||||
| GnnGraphDataResponsePb *response) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() == 1, "The number of neighbor number is not 1"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of neighbor types is not 1"); | |||||
| std::vector<NodeIdType> node_list; | |||||
| node_list.resize(request->id().size()); | |||||
| std::transform(request->id().begin(), request->id().end(), node_list.begin(), | |||||
| [](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); }); | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->GetNegSampledNeighbors(node_list, static_cast<NodeIdType>(request->number()[0]), | |||||
| static_cast<NodeType>(request->type()[0]), &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServiceImpl::RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input meta path is empty"); | |||||
| std::vector<NodeIdType> node_list; | |||||
| node_list.resize(request->id().size()); | |||||
| std::transform(request->id().begin(), request->id().end(), node_list.begin(), | |||||
| [](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); }); | |||||
| std::vector<NodeType> meta_path; | |||||
| meta_path.resize(request->type().size()); | |||||
| std::transform(request->type().begin(), request->type().end(), meta_path.begin(), | |||||
| [](const google::protobuf::int32 type) { return static_cast<NodeType>(type); }); | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->RandomWalk(node_list, meta_path, request->random_walk().p(), | |||||
| request->random_walk().q(), request->random_walk().default_id(), | |||||
| &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServiceImpl::GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||||
| std::shared_ptr<Tensor> nodes; | |||||
| RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &nodes)); | |||||
| for (const auto &type : request->type()) { | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->GetNodeFeatureSharedMemory(nodes, type, &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphDataServiceImpl::GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||||
| std::shared_ptr<Tensor> edges; | |||||
| RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &edges)); | |||||
| for (const auto &type : request->type()) { | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(graph_data_impl_->GetEdgeFeatureSharedMemory(edges, type, &tensor)); | |||||
| TensorPb *result = response->add_result_data(); | |||||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace gnn | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||||
| #include "proto/gnn_graph_data.grpc.pb.h" | |||||
| #include "proto/gnn_graph_data.pb.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace gnn { | |||||
| class GraphDataServer; | |||||
| class GraphDataServiceImpl { | |||||
| public: | |||||
| GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl); | |||||
| ~GraphDataServiceImpl() = default; | |||||
| grpc::Status ClientRegister(grpc::ServerContext *context, const GnnClientRegisterRequestPb *request, | |||||
| GnnClientRegisterResponsePb *response); | |||||
| grpc::Status ClientUnRegister(grpc::ServerContext *context, const GnnClientUnRegisterRequestPb *request, | |||||
| GnnClientUnRegisterResponsePb *response); | |||||
| grpc::Status GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request, | |||||
| GnnGraphDataResponsePb *response); | |||||
| grpc::Status GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request, | |||||
| GnnMetaInfoResponsePb *response); | |||||
| Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| Status RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| Status GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| Status GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||||
| private: | |||||
| Status FillDefaultFeature(GnnClientRegisterResponsePb *response); | |||||
| GraphDataServer *server_; | |||||
| GraphDataImpl *graph_data_impl_; | |||||
| }; | |||||
| } // namespace gnn | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_ | |||||
| @@ -0,0 +1,106 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/gnn/graph_feature_parser.h" | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace gnn { | |||||
| using mindrecord::MSRStatus; | |||||
| GraphFeatureParser::GraphFeatureParser(const ShardColumn &shard_column) { | |||||
| shard_column_ = std::make_unique<ShardColumn>(shard_column); | |||||
| } | |||||
| Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob, | |||||
| std::shared_ptr<Tensor> *tensor) { | |||||
| const unsigned char *data = nullptr; | |||||
| std::unique_ptr<unsigned char[]> data_ptr; | |||||
| uint64_t n_bytes = 0, col_type_size = 1; | |||||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||||
| std::vector<int64_t> column_shape; | |||||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||||
| &col_type_size, &column_shape); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); | |||||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})), | |||||
| std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), | |||||
| data, tensor)); | |||||
| return Status::OK(); | |||||
| } | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob, | |||||
| GraphSharedMemory *shared_memory, | |||||
| std::shared_ptr<Tensor> *out_tensor) { | |||||
| const unsigned char *data = nullptr; | |||||
| std::unique_ptr<unsigned char[]> data_ptr; | |||||
| uint64_t n_bytes = 0, col_type_size = 1; | |||||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||||
| std::vector<int64_t> column_shape; | |||||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||||
| &col_type_size, &column_shape); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); | |||||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||||
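| // The feature bytes themselves go into the shared memory segment; the tensor | |||||
| // returned to the caller is only a 2-element int64 descriptor {offset, n_bytes} | |||||
| // that a client process uses to read the feature back out of the same segment. | |||||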
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor)); | |||||
| auto fea_itr = tensor->begin<int64_t>(); | |||||
| int64_t offset = 0; | |||||
| RETURN_IF_NOT_OK(shared_memory->InsertData(data, n_bytes, &offset)); | |||||
| *fea_itr = offset; | |||||
| ++fea_itr; | |||||
| *fea_itr = n_bytes; | |||||
| *out_tensor = std::move(tensor); | |||||
| return Status::OK(); | |||||
| } | |||||
| #endif | |||||
| Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob, | |||||
| std::vector<int32_t> *indices) { | |||||
| const unsigned char *data = nullptr; | |||||
| std::unique_ptr<unsigned char[]> data_ptr; | |||||
| uint64_t n_bytes = 0, col_type_size = 1; | |||||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||||
| std::vector<int64_t> column_shape; | |||||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||||
| &col_type_size, &column_shape); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); | |||||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||||
| for (int i = 0; i < n_bytes; i += col_type_size) { | |||||
| int32_t feature_ind = -1; | |||||
| if (col_type == mindrecord::ColumnInt32) { | |||||
| feature_ind = *(reinterpret_cast<const int32_t *>(data + i)); | |||||
| } else if (col_type == mindrecord::ColumnInt64) { | |||||
| feature_ind = *(reinterpret_cast<const int64_t *>(data + i)); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); | |||||
| } | |||||
| if (feature_ind >= 0) indices->push_back(feature_ind); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace gnn | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,67 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_ | |||||
| #include <memory> | |||||
| #include <queue> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <unordered_set> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/core/data_type.h" | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||||
| #endif | |||||
| #include "minddata/dataset/engine/gnn/feature.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/mindrecord/include/shard_column.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace gnn { | |||||
| using mindrecord::ShardColumn; | |||||
| class GraphFeatureParser { | |||||
| public: | |||||
| explicit GraphFeatureParser(const ShardColumn &shard_column); | |||||
| ~GraphFeatureParser() = default; | |||||
| // @param std::string key - column name | |||||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | |||||
| // @param std::vector<int32_t> *ind - return value, list of feature index in int32_t | |||||
| // @return Status - the status code | |||||
| Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, std::vector<int32_t> *ind); | |||||
| // @param std::string &key - column name | |||||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | |||||
| // @param std::shared_ptr<Tensor> *tensor - return value feature tensor | |||||
| // @return Status - the status code | |||||
| Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, std::shared_ptr<Tensor> *tensor); | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| Status LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob, | |||||
| GraphSharedMemory *shared_memory, std::shared_ptr<Tensor> *out_tensor); | |||||
| #endif | |||||
| private: | |||||
| std::unique_ptr<ShardColumn> shard_column_; | |||||
| }; | |||||
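| // Usage sketch (illustrative; mirrors how GraphLoader consumes this class): read the | |||||
| // feature indices from a row's blob, then load each feature tensor by its generated | |||||
| // column name: | |||||
| // | |||||
| //   GraphFeatureParser parser(*shard_reader->GetShardColumn()); | |||||
| //   std::vector<int32_t> indices; | |||||
| //   RETURN_IF_NOT_OK(parser.LoadFeatureIndex("node_feature_index", blob, &indices)); | |||||
| //   for (int32_t ind : indices) { | |||||
| //     std::shared_ptr<Tensor> feature; | |||||
| //     RETURN_IF_NOT_OK( | |||||
| //       parser.LoadFeatureTensor("node_feature_" + std::to_string(ind), blob, &feature)); | |||||
| //   } | |||||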
| } // namespace gnn | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_ | |||||
| @@ -13,41 +13,42 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | |||||
| #include <future> | #include <future> | ||||
| #include <tuple> | #include <tuple> | ||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | |||||
| #include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||||
| #include "minddata/dataset/engine/gnn/local_edge.h" | #include "minddata/dataset/engine/gnn/local_edge.h" | ||||
| #include "minddata/dataset/engine/gnn/local_node.h" | #include "minddata/dataset/engine/gnn/local_node.h" | ||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| #include "minddata/mindrecord/include/shard_error.h" | |||||
| using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>; | using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| namespace gnn { | namespace gnn { | ||||
| using mindrecord::MSRStatus; | using mindrecord::MSRStatus; | ||||
| GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) | |||||
| : mr_path_(mr_filepath), | |||||
| GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers, bool server_mode) | |||||
| : graph_impl_(graph_impl), | |||||
| mr_path_(mr_filepath), | |||||
| num_workers_(num_workers), | num_workers_(num_workers), | ||||
| row_id_(0), | row_id_(0), | ||||
| shard_reader_(nullptr), | shard_reader_(nullptr), | ||||
| graph_feature_parser_(nullptr), | |||||
| keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} | keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} | ||||
| Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, | |||||
| EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, | |||||
| EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map, | |||||
| DefaultEdgeFeatureMap *default_edge_feature_map) { | |||||
| Status GraphLoader::GetNodesAndEdges() { | |||||
| NodeIdMap *n_id_map = &graph_impl_->node_id_map_; | |||||
| EdgeIdMap *e_id_map = &graph_impl_->edge_id_map_; | |||||
| for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) { | for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) { | ||||
| while (dq.empty() == false) { | while (dq.empty() == false) { | ||||
| std::shared_ptr<Node> node_ptr = dq.front(); | std::shared_ptr<Node> node_ptr = dq.front(); | ||||
| n_id_map->insert({node_ptr->id(), node_ptr}); | n_id_map->insert({node_ptr->id(), node_ptr}); | ||||
| (*n_type_map)[node_ptr->type()].push_back(node_ptr->id()); | |||||
| graph_impl_->node_type_map_[node_ptr->type()].push_back(node_ptr->id()); | |||||
| dq.pop_front(); | dq.pop_front(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -63,15 +64,15 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N | |||||
| RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); | RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); | ||||
| RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); | RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); | ||||
| e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ | e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ | ||||
| (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id()); | |||||
| graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); | |||||
| dq.pop_front(); | dq.pop_front(); | ||||
| } | } | ||||
| } | } | ||||
| for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); | |||||
| for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); | |||||
| for (auto &itr : graph_impl_->node_type_map_) itr.second.shrink_to_fit(); | |||||
| for (auto &itr : graph_impl_->edge_type_map_) itr.second.shrink_to_fit(); | |||||
| MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map); | |||||
| MergeFeatureMaps(); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -92,13 +93,26 @@ Status GraphLoader::InitAndLoad() { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); | CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); | CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); | ||||
| mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"]; | |||||
| graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); | |||||
| mindrecord::json schema = graph_impl_->data_schema_["schema"]; | |||||
| for (const std::string &key : keys_) { | for (const std::string &key : keys_) { | ||||
| if (schema.find(key) == schema.end()) { | if (schema.find(key) == schema.end()) { | ||||
| RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); | RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); | ||||
| } | } | ||||
| } | } | ||||
| if (graph_impl_->server_mode_) { | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| int64_t total_blob_size = 0; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS, | |||||
| "failed to get total blob size"); | |||||
| graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_); | |||||
| RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory()); | |||||
| #endif | |||||
| } | |||||
| graph_feature_parser_ = std::make_unique<GraphFeatureParser>(*shard_reader_->GetShardColumn()); | |||||
| // launching worker threads | // launching worker threads | ||||
| for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { | for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { | ||||
| RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); | RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); | ||||
| @@ -116,18 +130,39 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec | |||||
| NodeType node_type = static_cast<NodeType>(col_jsn["type"]); | NodeType node_type = static_cast<NodeType>(col_jsn["type"]); | ||||
| (*node) = std::make_shared<LocalNode>(node_id, node_type); | (*node) = std::make_shared<LocalNode>(node_id, node_type); | ||||
| std::vector<int32_t> indices; | std::vector<int32_t> indices; | ||||
| RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices)); | |||||
| for (int32_t ind : indices) { | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); | |||||
| RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); | |||||
| (*feature_map)[node_type].insert(ind); | |||||
| if ((*default_feature)[ind] == nullptr) { | |||||
| std::shared_ptr<Tensor> zero_tensor; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices)); | |||||
| if (graph_impl_->server_mode_) { | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| for (int32_t ind : indices) { | |||||
| std::shared_ptr<Tensor> tensor_sm; | |||||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory( | |||||
| "node_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm)); | |||||
| RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true))); | |||||
| (*feature_map)[node_type].insert(ind); | |||||
| if ((*default_feature)[ind] == nullptr) { | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK( | |||||
| graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor)); | |||||
| std::shared_ptr<Tensor> zero_tensor; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } else { | |||||
| for (int32_t ind : indices) { | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK( | |||||
| graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor)); | |||||
| RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); | |||||
| (*feature_map)[node_type].insert(ind); | |||||
| if ((*default_feature)[ind] == nullptr) { | |||||
| std::shared_ptr<Tensor> zero_tensor; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -143,63 +178,42 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec | |||||
| std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1); | std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1); | ||||
| (*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst); | (*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst); | ||||
| std::vector<int32_t> indices; | std::vector<int32_t> indices; | ||||
| RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices)); | |||||
| for (int32_t ind : indices) { | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); | |||||
| RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); | |||||
| (*feature_map)[edge_type].insert(ind); | |||||
| if ((*default_feature)[ind] == nullptr) { | |||||
| std::shared_ptr<Tensor> zero_tensor; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices)); | |||||
| if (graph_impl_->server_mode_) { | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| for (int32_t ind : indices) { | |||||
| std::shared_ptr<Tensor> tensor_sm; | |||||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory( | |||||
| "edge_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm)); | |||||
| RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true))); | |||||
| (*feature_map)[edge_type].insert(ind); | |||||
| if ((*default_feature)[ind] == nullptr) { | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK( | |||||
| graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor)); | |||||
| std::shared_ptr<Tensor> zero_tensor; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob, | |||||
| const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) { | |||||
| const unsigned char *data = nullptr; | |||||
| std::unique_ptr<unsigned char[]> data_ptr; | |||||
| uint64_t n_bytes = 0, col_type_size = 1; | |||||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||||
| std::vector<int64_t> column_shape; | |||||
| MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( | |||||
| key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); | |||||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})), | |||||
| std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), | |||||
| data, tensor)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob, | |||||
| const mindrecord::json &col_jsn, std::vector<int32_t> *indices) { | |||||
| const unsigned char *data = nullptr; | |||||
| std::unique_ptr<unsigned char[]> data_ptr; | |||||
| uint64_t n_bytes = 0, col_type_size = 1; | |||||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||||
| std::vector<int64_t> column_shape; | |||||
| MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( | |||||
| key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); | |||||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||||
| for (int i = 0; i < n_bytes; i += col_type_size) { | |||||
| int32_t feature_ind = -1; | |||||
| if (col_type == mindrecord::ColumnInt32) { | |||||
| feature_ind = *(reinterpret_cast<const int32_t *>(data + i)); | |||||
| } else if (col_type == mindrecord::ColumnInt64) { | |||||
| feature_ind = *(reinterpret_cast<const int64_t *>(data + i)); | |||||
| } else { | |||||
| RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); | |||||
| #endif | |||||
| } else { | |||||
| for (int32_t ind : indices) { | |||||
| std::shared_ptr<Tensor> tensor; | |||||
| RETURN_IF_NOT_OK( | |||||
| graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor)); | |||||
| RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); | |||||
| (*feature_map)[edge_type].insert(ind); | |||||
| if ((*default_feature)[ind] == nullptr) { | |||||
| std::shared_ptr<Tensor> zero_tensor; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||||
| } | |||||
| } | } | ||||
| if (feature_ind >= 0) indices->push_back(feature_ind); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -234,21 +248,19 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, | |||||
| DefaultNodeFeatureMap *default_node_feature_map, | |||||
| DefaultEdgeFeatureMap *default_edge_feature_map) { | |||||
| void GraphLoader::MergeFeatureMaps() { | |||||
| for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { | for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { | ||||
| for (auto &m : n_feature_maps_[wkr_id]) { | for (auto &m : n_feature_maps_[wkr_id]) { | ||||
| for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); | |||||
| for (auto &n : m.second) graph_impl_->node_feature_map_[m.first].insert(n); | |||||
| } | } | ||||
| for (auto &m : e_feature_maps_[wkr_id]) { | for (auto &m : e_feature_maps_[wkr_id]) { | ||||
| for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); | |||||
| for (auto &n : m.second) graph_impl_->edge_feature_map_[m.first].insert(n); | |||||
| } | } | ||||
| for (auto &m : default_node_feature_maps_[wkr_id]) { | for (auto &m : default_node_feature_maps_[wkr_id]) { | ||||
| (*default_node_feature_map)[m.first] = m.second; | |||||
| graph_impl_->default_node_feature_map_[m.first] = m.second; | |||||
| } | } | ||||
| for (auto &m : default_edge_feature_maps_[wkr_id]) { | for (auto &m : default_edge_feature_maps_[wkr_id]) { | ||||
| (*default_edge_feature_map)[m.first] = m.second; | |||||
| graph_impl_->default_edge_feature_map_[m.first] = m.second; | |||||
| } | } | ||||
| } | } | ||||
| n_feature_maps_.clear(); | n_feature_maps_.clear(); | ||||
| @@ -26,10 +26,13 @@ | |||||
| #include "minddata/dataset/core/data_type.h" | #include "minddata/dataset/core/data_type.h" | ||||
| #include "minddata/dataset/core/tensor.h" | #include "minddata/dataset/core/tensor.h" | ||||
| #include "minddata/dataset/engine/gnn/edge.h" | |||||
| #include "minddata/dataset/engine/gnn/feature.h" | #include "minddata/dataset/engine/gnn/feature.h" | ||||
| #include "minddata/dataset/engine/gnn/graph.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_feature_parser.h" | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | |||||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||||
| #endif | |||||
| #include "minddata/dataset/engine/gnn/node.h" | #include "minddata/dataset/engine/gnn/node.h" | ||||
| #include "minddata/dataset/engine/gnn/edge.h" | |||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "minddata/mindrecord/include/shard_reader.h" | #include "minddata/mindrecord/include/shard_reader.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -46,13 +49,15 @@ using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureTy | |||||
| using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; | using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; | ||||
| using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; | using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; | ||||
| class GraphDataImpl; | |||||
| // this class interfaces with the underlying storage format (mindrecord) | // this class interfaces with the underlying storage format (mindrecord) | ||||
| // it returns raw nodes and edges via GetNodesAndEdges | // it returns raw nodes and edges via GetNodesAndEdges | ||||
| // it is then the responsibility of graph to construct itself based on the nodes and edges | // it is then the responsibility of graph to construct itself based on the nodes and edges | ||||
| // if needed, this class could become a base where each derived class handles a specific storage format | // if needed, this class could become a base where each derived class handles a specific storage format | ||||
| class GraphLoader { | class GraphLoader { | ||||
| public: | public: | ||||
| explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4); | |||||
| GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false); | |||||
| ~GraphLoader() = default; | ~GraphLoader() = default; | ||||
| // Init mindrecord and load everything into memory multi-threaded | // Init mindrecord and load everything into memory multi-threaded | ||||
| @@ -63,8 +68,7 @@ class GraphLoader { | |||||
| // nodes and edges are added to map without any connection. That's because there nodes and edges are read in | // nodes and edges are added to map without any connection. That's because there nodes and edges are read in | ||||
| // random order. src_node and dst_node in Edge are node_id only with -1 as type. | // random order. src_node and dst_node in Edge are node_id only with -1 as type. | ||||
| // features attached to each node and edge are expected to be filled correctly | // features attached to each node and edge are expected to be filled correctly | ||||
| Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, | |||||
| DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); | |||||
| Status GetNodesAndEdges(); | |||||
| private: | private: | ||||
| // | // | ||||
| @@ -92,29 +96,15 @@ class GraphLoader { | |||||
| Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge, | Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge, | ||||
| EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); | EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); | ||||
| // @param std::string key - column name | |||||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | |||||
| // @param mindrecord::json &jsn - contains raw data | |||||
| // @param std::vector<int32_t> *ind - return value, list of feature index in int32_t | |||||
| // @return Status - the status code | |||||
| Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn, | |||||
| std::vector<int32_t> *ind); | |||||
| // @param std::string &key - column name | |||||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | |||||
| // @param mindrecord::json &jsn - contains raw data | |||||
| // @param std::shared_ptr<Tensor> *tensor - return value feature tensor | |||||
| // @return Status - the status code | |||||
| Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn, | |||||
| std::shared_ptr<Tensor> *tensor); | |||||
| // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 | // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 | ||||
| void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); | |||||
| void MergeFeatureMaps(); | |||||
| GraphDataImpl *graph_impl_; | |||||
| std::string mr_path_; | |||||
| const int32_t num_workers_; | const int32_t num_workers_; | ||||
| std::atomic_int row_id_; | std::atomic_int row_id_; | ||||
| std::string mr_path_; | |||||
| std::unique_ptr<ShardReader> shard_reader_; | std::unique_ptr<ShardReader> shard_reader_; | ||||
| std::unique_ptr<GraphFeatureParser> graph_feature_parser_; | |||||
| std::vector<std::deque<std::shared_ptr<Node>>> n_deques_; | std::vector<std::deque<std::shared_ptr<Node>>> n_deques_; | ||||
| std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_; | std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_; | ||||
| std::vector<NodeFeatureMap> n_feature_maps_; | std::vector<NodeFeatureMap> n_feature_maps_; | ||||
| @@ -0,0 +1,134 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||||
| #include <string> | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace gnn { | |||||
| GraphSharedMemory::GraphSharedMemory(int64_t memory_size, key_t memory_key) | |||||
| : memory_size_(memory_size), | |||||
| memory_key_(memory_key), | |||||
| memory_ptr_(nullptr), | |||||
| memory_offset_(0), | |||||
| is_new_create_(false) { | |||||
| std::stringstream stream; | |||||
| stream << std::hex << memory_key_; | |||||
| memory_key_str_ = stream.str(); | |||||
| } | |||||
| GraphSharedMemory::GraphSharedMemory(int64_t memory_size, const std::string &mr_file) | |||||
| : mr_file_(mr_file), | |||||
| memory_size_(memory_size), | |||||
| memory_key_(-1), | |||||
| memory_ptr_(nullptr), | |||||
| memory_offset_(0), | |||||
| is_new_create_(false) {} | |||||
| GraphSharedMemory::~GraphSharedMemory() { | |||||
| if (is_new_create_) { | |||||
| (void)DeleteSharedMemory(); | |||||
| } | |||||
| } | |||||
| Status GraphSharedMemory::CreateSharedMemory() { | |||||
| if (memory_key_ == -1) { | |||||
| // ftok to generate unique key | |||||
| memory_key_ = ftok(mr_file_.data(), kGnnSharedMemoryId); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(memory_key_ != -1, "Failed to get key of shared memory. file_name:" + mr_file_); | |||||
| std::stringstream stream; | |||||
| stream << std::hex << memory_key_; | |||||
| memory_key_str_ = stream.str(); | |||||
| } | |||||
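| // Try to create the segment exclusively first; if one with this key already exists | |||||
| // (e.g. left over from a previous run), fall back to attaching to it. | |||||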
| int shmflg = (0666 | IPC_CREAT | IPC_EXCL); | |||||
| Status s = SharedMemoryImpl(shmflg); | |||||
| if (s.IsOk()) { | |||||
| is_new_create_ = true; | |||||
| MS_LOG(INFO) << "Create shared memory success, key=0x" << memory_key_str_; | |||||
| } else { | |||||
| MS_LOG(WARNING) << "Shared memory with the same key may already exist, key=0x" << memory_key_str_; | |||||
| shmflg = (0666 | IPC_CREAT); | |||||
| s = SharedMemoryImpl(shmflg); | |||||
| if (!s.IsOk()) { | |||||
| RETURN_STATUS_UNEXPECTED("Create shared memory fao;ed, key=0x" + memory_key_str_); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphSharedMemory::GetSharedMemory() { | |||||
| int shmflg = 0; | |||||
| RETURN_IF_NOT_OK(SharedMemoryImpl(shmflg)); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphSharedMemory::DeleteSharedMemory() { | |||||
| int shmid = shmget(memory_key_, 0, 0); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_); | |||||
| int result = shmctl(shmid, IPC_RMID, 0); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(result != -1, "Failed to delete shared memory. key=0x" + memory_key_str_); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphSharedMemory::SharedMemoryImpl(const int &shmflg) { | |||||
| // shmget returns an identifier in shmid | |||||
| int shmid = shmget(memory_key_, memory_size_, shmflg); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_); | |||||
| // shmat to attach to shared memory | |||||
| auto data = shmat(shmid, reinterpret_cast<void *>(0), 0); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(data != (char *)(-1), "Failed to address shared memory. key=0x" + memory_key_str_); | |||||
| memory_ptr_ = reinterpret_cast<uint8_t *>(data); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphSharedMemory::InsertData(const uint8_t *data, int64_t len, int64_t *offset) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(len > 0, "Input len is invalid."); | |||||
| std::lock_guard<std::mutex> lck(mutex_); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED((memory_size_ - memory_offset_ >= len), | |||||
| "Insufficient shared memory space to insert data."); | |||||
| if (EOK != memcpy_s(memory_ptr_ + memory_offset_, memory_size_ - memory_offset_, data, len)) { | |||||
| RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory."); | |||||
| } | |||||
| *offset = memory_offset_; | |||||
| memory_offset_ += len; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GraphSharedMemory::GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(get_data_len > 0, "Input get_data_len is invalid."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(data_len >= get_data_len, "Insufficient target address space."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(memory_size_ >= get_data_len + offset, | |||||
| "get_data_len is too large, beyond the space of shared memory."); | |||||
| if (EOK != memcpy_s(data, data_len, memory_ptr_ + offset, get_data_len)) { | |||||
| RETURN_STATUS_UNEXPECTED("Failed to get data from shared memory."); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace gnn | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,72 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_ | |||||
| #include <sys/ipc.h> | |||||
| #include <sys/shm.h> | |||||
| #include <mutex> | |||||
| #include <string> | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| namespace gnn { | |||||
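| // proj_id passed to ftok() when deriving the shared memory key from the mindrecord file path. | |||||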
| const int kGnnSharedMemoryId = 65; | |||||
| class GraphSharedMemory { | |||||
| public: | |||||
| explicit GraphSharedMemory(int64_t memory_size, key_t memory_key); | |||||
| explicit GraphSharedMemory(int64_t memory_size, const std::string &mr_file); | |||||
| ~GraphSharedMemory(); | |||||
| // @brief create a shared memory segment, or attach to the existing one with the same key | |||||
| // @return Status - the status code | |||||
| Status CreateSharedMemory(); | |||||
| // @brief attach to an existing shared memory segment | |||||
| // @return Status - the status code | |||||
| Status GetSharedMemory(); | |||||
| Status DeleteSharedMemory(); | |||||
| Status InsertData(const uint8_t *data, int64_t len, int64_t *offset); | |||||
| Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len); | |||||
| key_t memory_key() { return memory_key_; } | |||||
| int64_t memory_size() { return memory_size_; } | |||||
| private: | |||||
| Status SharedMemoryImpl(const int &shmflg); | |||||
| std::string mr_file_; | |||||
| int64_t memory_size_; | |||||
| key_t memory_key_; | |||||
| std::string memory_key_str_; | |||||
| uint8_t *memory_ptr_; | |||||
| int64_t memory_offset_; | |||||
| std::mutex mutex_; | |||||
| bool is_new_create_; | |||||
| }; | |||||
| } // namespace gnn | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_ | |||||
| @@ -0,0 +1,82 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/gnn/grpc_async_server.h" | |||||
| #include <limits> | |||||
| #include "minddata/dataset/util/task_manager.h" | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| GrpcAsyncServer::GrpcAsyncServer(const std::string &host, int32_t port) : host_(host), port_(port) {} | |||||
| GrpcAsyncServer::~GrpcAsyncServer() { Stop(); } | |||||
| Status GrpcAsyncServer::Run() { | |||||
| std::string server_address = host_ + ":" + std::to_string(port_); | |||||
| grpc::ServerBuilder builder; | |||||
| // Default max message size for gRPC is 4MB. Increase it to 2GB - 1 (INT32_MAX). | |||||
| builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max()); | |||||
| builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); | |||||
| int port_tcpip = 0; | |||||
| builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip); | |||||
| RETURN_IF_NOT_OK(RegisterService(&builder)); | |||||
| cq_ = builder.AddCompletionQueue(); | |||||
| server_ = builder.BuildAndStart(); | |||||
| if (server_) { | |||||
| MS_LOG(INFO) << "Server listening on " << server_address; | |||||
| } else { | |||||
| std::string errMsg = "Failed to start server. "; | |||||
| if (port_tcpip != port_) { | |||||
| errMsg += "Unable to bind to address " + server_address + "."; | |||||
| } | |||||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GrpcAsyncServer::HandleRequest() { | |||||
| bool success; | |||||
| void *tag; | |||||
| // We loop over the gRPC completion queue. Each successful event comes back | |||||
| // with the tag we supplied, an instance of CallData, and we simply invoke | |||||
| // its functor. But first we need to create these instances and inject them | |||||
| // into the completion queue. | |||||
| RETURN_IF_NOT_OK(EnqueueRequest()); | |||||
| while (cq_->Next(&tag, &success)) { | |||||
| RETURN_IF_INTERRUPTED(); | |||||
| if (success) { | |||||
| RETURN_IF_NOT_OK(ProcessRequest(tag)); | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "cq_->Next failed."; | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| void GrpcAsyncServer::Stop() { | |||||
| if (server_) { | |||||
| server_->Shutdown(); | |||||
| } | |||||
| // Always shutdown the completion queue after the server. | |||||
| if (cq_) { | |||||
| cq_->Shutdown(); | |||||
| } | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "grpcpp/grpcpp.h" | |||||
| #include "grpcpp/impl/codegen/async_unary_call.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// \brief Async server base class | |||||
| class GrpcAsyncServer { | |||||
| public: | |||||
| explicit GrpcAsyncServer(const std::string &host, int32_t port); | |||||
| virtual ~GrpcAsyncServer(); | |||||
| /// \brief Brings up gRPC server | |||||
| /// \return Status - the status code | |||||
| Status Run(); | |||||
| /// \brief Entry function to handle async server request | |||||
| Status HandleRequest(); | |||||
| void Stop(); | |||||
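| /// \brief Register the concrete service with the server builder | |||||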
| virtual Status RegisterService(grpc::ServerBuilder *builder) = 0; | |||||
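| /// \brief Add a new pending request (tag) to the completion queue | |||||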
| virtual Status EnqueueRequest() = 0; | |||||
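| /// \brief Dispatch one completed tag returned by the completion queue | |||||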
| virtual Status ProcessRequest(void *tag) = 0; | |||||
| protected: | |||||
| std::string host_; | |||||
| int32_t port_; | |||||
| std::unique_ptr<grpc::ServerCompletionQueue> cq_; | |||||
| std::unique_ptr<grpc::Server> server_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_ | |||||
| @@ -44,6 +44,7 @@ Status LocalEdge::UpdateFeature(const std::shared_ptr<Feature> &feature) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } | } | ||||
| } // namespace gnn | } // namespace gnn | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,10 +20,10 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/engine/gnn/edge.h" | #include "minddata/dataset/engine/gnn/edge.h" | ||||
| #include "minddata/dataset/engine/gnn/feature.h" | #include "minddata/dataset/engine/gnn/feature.h" | ||||
| #include "minddata/dataset/engine/gnn/node.h" | #include "minddata/dataset/engine/gnn/node.h" | ||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -20,9 +20,9 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/engine/gnn/node.h" | #include "minddata/dataset/engine/gnn/node.h" | ||||
| #include "minddata/dataset/engine/gnn/feature.h" | #include "minddata/dataset/engine/gnn/feature.h" | ||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -20,8 +20,8 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/engine/gnn/feature.h" | #include "minddata/dataset/engine/gnn/feature.h" | ||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -0,0 +1,84 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/engine/gnn/tensor_proto.h" | |||||
| #include <algorithm> | |||||
| #include <utility> | |||||
| #include <unordered_map> | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| const std::unordered_map<DataTypePb, DataType::Type> g_pb2datatype_map{ | |||||
| {DataTypePb::DE_PB_UNKNOWN, DataType::DE_UNKNOWN}, {DataTypePb::DE_PB_BOOL, DataType::DE_BOOL}, | |||||
| {DataTypePb::DE_PB_INT8, DataType::DE_INT8}, {DataTypePb::DE_PB_UINT8, DataType::DE_UINT8}, | |||||
| {DataTypePb::DE_PB_INT16, DataType::DE_INT16}, {DataTypePb::DE_PB_UINT16, DataType::DE_UINT16}, | |||||
| {DataTypePb::DE_PB_INT32, DataType::DE_INT32}, {DataTypePb::DE_PB_UINT32, DataType::DE_UINT32}, | |||||
| {DataTypePb::DE_PB_INT64, DataType::DE_INT64}, {DataTypePb::DE_PB_UINT64, DataType::DE_UINT64}, | |||||
| {DataTypePb::DE_PB_FLOAT16, DataType::DE_FLOAT16}, {DataTypePb::DE_PB_FLOAT32, DataType::DE_FLOAT32}, | |||||
| {DataTypePb::DE_PB_FLOAT64, DataType::DE_FLOAT64}, {DataTypePb::DE_PB_STRING, DataType::DE_STRING}, | |||||
| }; | |||||
| const std::unordered_map<DataType::Type, DataTypePb> g_datatype2pb_map{ | |||||
| {DataType::DE_UNKNOWN, DataTypePb::DE_PB_UNKNOWN}, {DataType::DE_BOOL, DataTypePb::DE_PB_BOOL}, | |||||
| {DataType::DE_INT8, DataTypePb::DE_PB_INT8}, {DataType::DE_UINT8, DataTypePb::DE_PB_UINT8}, | |||||
| {DataType::DE_INT16, DataTypePb::DE_PB_INT16}, {DataType::DE_UINT16, DataTypePb::DE_PB_UINT16}, | |||||
| {DataType::DE_INT32, DataTypePb::DE_PB_INT32}, {DataType::DE_UINT32, DataTypePb::DE_PB_UINT32}, | |||||
| {DataType::DE_INT64, DataTypePb::DE_PB_INT64}, {DataType::DE_UINT64, DataTypePb::DE_PB_UINT64}, | |||||
| {DataType::DE_FLOAT16, DataTypePb::DE_PB_FLOAT16}, {DataType::DE_FLOAT32, DataTypePb::DE_PB_FLOAT32}, | |||||
| {DataType::DE_FLOAT64, DataTypePb::DE_PB_FLOAT64}, {DataType::DE_STRING, DataTypePb::DE_PB_STRING}, | |||||
| }; | |||||
| Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer"); | |||||
| std::vector<dsize_t> shape = tensor->shape().AsVector(); | |||||
| for (auto dim : shape) { | |||||
| tensor_pb->add_dims(static_cast<google::protobuf::int64>(dim)); | |||||
| } | |||||
| auto iter = g_datatype2pb_map.find(tensor->type().value()); | |||||
| if (iter == g_datatype2pb_map.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("Invalid tensor type: " + tensor->type().ToString()); | |||||
| } | |||||
| tensor_pb->set_tensor_type(iter->second); | |||||
| tensor_pb->set_data(tensor->GetBuffer(), tensor->SizeInBytes()); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer"); | |||||
| std::vector<dsize_t> shape; | |||||
| shape.resize(tensor_pb->dims().size()); | |||||
| std::transform(tensor_pb->dims().begin(), tensor_pb->dims().end(), shape.begin(), | |||||
| [](const google::protobuf::int64 dim) { return static_cast<dsize_t>(dim); }); | |||||
| auto iter = g_pb2datatype_map.find(tensor_pb->tensor_type()); | |||||
| if (iter == g_pb2datatype_map.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("Invalid tensor_pb type: " + std::to_string(tensor_pb->tensor_type())); | |||||
| } | |||||
| DataType::Type type = iter->second; | |||||
| std::shared_ptr<Tensor> tensor_out; | |||||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape(shape), DataType(type), | |||||
| reinterpret_cast<const unsigned char *>(tensor_pb->data().data()), | |||||
| tensor_pb->data().size(), &tensor_out)); | |||||
| *tensor = std::move(tensor_out); | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_ | |||||
| #include <deque> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "proto/gnn_tensor.pb.h" | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb); | |||||
| Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor); | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_ | |||||
| @@ -61,6 +61,7 @@ const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = { | |||||
| class ShardColumn { | class ShardColumn { | ||||
| public: | public: | ||||
| explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true); | explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true); | ||||
| explicit ShardColumn(const json &schema_json, bool compress_integer = true); | |||||
| ~ShardColumn() = default; | ~ShardColumn() = default; | ||||
| @@ -72,23 +73,29 @@ class ShardColumn { | |||||
| std::vector<int64_t> *column_shape); | std::vector<int64_t> *column_shape); | ||||
| /// \brief compress blob | /// \brief compress blob | ||||
| std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob); | |||||
| std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size); | |||||
| /// \brief check if blob compressed | /// \brief check if blob compressed | ||||
| bool CheckCompressBlob() const { return has_compress_blob_; } | bool CheckCompressBlob() const { return has_compress_blob_; } | ||||
| /// \brief getter | |||||
| uint64_t GetNumBlobColumn() const { return num_blob_column_; } | uint64_t GetNumBlobColumn() const { return num_blob_column_; } | ||||
| /// \brief getter | |||||
| std::vector<std::string> GetColumnName() { return column_name_; } | std::vector<std::string> GetColumnName() { return column_name_; } | ||||
| /// \brief getter | |||||
| std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; } | std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; } | ||||
| /// \brief getter | |||||
| std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; } | std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; } | ||||
| /// \brief get column value from blob | /// \brief get column value from blob | ||||
| MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | ||||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | ||||
| uint64_t *const n_bytes); | uint64_t *const n_bytes); | ||||
| /// \brief get column type | |||||
| std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name, | std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name, | ||||
| ColumnDataType *column_data_type, | ColumnDataType *column_data_type, | ||||
| uint64_t *column_data_type_size, | uint64_t *column_data_type_size, | ||||
| @@ -99,6 +106,9 @@ class ShardColumn { | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes); | std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes); | ||||
| private: | private: | ||||
| /// \brief initialization | |||||
| void Init(const json &schema_json, bool compress_integer = true); | |||||
| /// \brief get float value from json | /// \brief get float value from json | ||||
| template <typename T> | template <typename T> | ||||
| MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double); | MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double); | ||||
| @@ -65,6 +65,11 @@ class ShardHeader { | |||||
| /// \return the Statistic | /// \return the Statistic | ||||
| std::vector<std::shared_ptr<Statistics>> GetStatistics(); | std::vector<std::shared_ptr<Statistics>> GetStatistics(); | ||||
| /// \brief get the statistic of the slim size | |||||
| /// \param[in] slim_size_json json info of the slim size | |||||
| /// \return the slim size | |||||
| int64_t GetSlimSizeStatistic(const json &slim_size_json); | |||||
| /// \brief get the fields of the index | /// \brief get the fields of the index | ||||
| /// \return the fields of the index | /// \return the fields of the index | ||||
| std::vector<std::pair<uint64_t, std::string>> GetFields(); | std::vector<std::pair<uint64_t, std::string>> GetFields(); | ||||
| @@ -114,10 +119,14 @@ class ShardHeader { | |||||
| uint64_t GetPageSize() const { return page_size_; } | uint64_t GetPageSize() const { return page_size_; } | ||||
| uint64_t GetCompressionSize() const { return compression_size_; } | |||||
| void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } | void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } | ||||
| void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } | void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } | ||||
| void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; } | |||||
| std::vector<std::string> SerializeHeader(); | std::vector<std::string> SerializeHeader(); | ||||
| MSRStatus PagesToFile(const std::string dump_file_name); | MSRStatus PagesToFile(const std::string dump_file_name); | ||||
| @@ -177,6 +186,7 @@ class ShardHeader { | |||||
| uint32_t shard_count_; | uint32_t shard_count_; | ||||
| uint64_t header_size_; | uint64_t header_size_; | ||||
| uint64_t page_size_; | uint64_t page_size_; | ||||
| uint64_t compression_size_; | |||||
| std::shared_ptr<Index> index_; | std::shared_ptr<Index> index_; | ||||
| std::vector<std::string> shard_addresses_; | std::vector<std::string> shard_addresses_; | ||||
| @@ -209,6 +209,9 @@ class ShardReader { | |||||
| /// \brief get all classes | /// \brief get all classes | ||||
| MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories); | MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories); | ||||
| /// \brief get the size of blob data | |||||
| MSRStatus GetTotalBlobSize(int64_t *total_blob_size); | |||||
| protected: | protected: | ||||
| /// \brief sqlite call back function | /// \brief sqlite call back function | ||||
| static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); | static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); | ||||
| @@ -323,6 +326,7 @@ class ShardReader { | |||||
| const std::string kThreadName = "THRD_ITER_"; // prefix of thread name | const std::string kThreadName = "THRD_ITER_"; // prefix of thread name | ||||
| std::vector<std::thread> thread_set_; // thread list | std::vector<std::thread> thread_set_; // thread list | ||||
| int num_rows_; // number of rows | int num_rows_; // number of rows | ||||
| int64_t total_blob_size_; // total size of blob data | |||||
| std::mutex mtx_delivery_; // locker for delivery | std::mutex mtx_delivery_; // locker for delivery | ||||
| std::condition_variable cv_delivery_; // conditional variable for delivery | std::condition_variable cv_delivery_; // conditional variable for delivery | ||||
| std::condition_variable cv_iterator_; // conditional variable for iterator | std::condition_variable cv_iterator_; // conditional variable for iterator | ||||
| @@ -257,6 +257,7 @@ class ShardWriter { | |||||
| std::mutex check_mutex_; // mutex for data check | std::mutex check_mutex_; // mutex for data check | ||||
| std::atomic<bool> flag_{false}; | std::atomic<bool> flag_{false}; | ||||
| std::atomic<int64_t> compression_size_; | |||||
| }; | }; | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -43,6 +43,7 @@ ShardReader::ShardReader() { | |||||
| page_size_ = 0; | page_size_ = 0; | ||||
| header_size_ = 0; | header_size_ = 0; | ||||
| num_rows_ = 0; | num_rows_ = 0; | ||||
| total_blob_size_ = 0; | |||||
| num_padded_ = 0; | num_padded_ = 0; | ||||
| } | } | ||||
| @@ -55,9 +56,11 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s | |||||
| return {FAILED, {}}; | return {FAILED, {}}; | ||||
| } | } | ||||
| auto header = ret.second; | auto header = ret.second; | ||||
| meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, | |||||
| {"version", header["version"]}, {"index_fields", header["index_fields"]}, | |||||
| {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; | |||||
| uint64_t compression_size = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0; | |||||
| meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, | |||||
| {"compression_size", compression_size}, {"version", header["version"]}, | |||||
| {"index_fields", header["index_fields"]}, {"schema", header["schema"]}, | |||||
| {"blob_fields", header["blob_fields"]}}; | |||||
| return {SUCCESS, header["shard_addresses"]}; | return {SUCCESS, header["shard_addresses"]}; | ||||
| } | } | ||||
| @@ -145,6 +148,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa | |||||
| for (const auto &rg : row_group_summary) { | for (const auto &rg : row_group_summary) { | ||||
| num_rows_ += std::get<3>(rg); | num_rows_ += std::get<3>(rg); | ||||
| } | } | ||||
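| // Estimate the in-memory blob size: bytes on disk plus the bytes saved by integer compression at write time. | |||||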
| auto disk_size = page_size_ * row_group_summary.size(); | |||||
| auto compression_size = shard_header_->GetCompressionSize(); | |||||
| total_blob_size_ = disk_size + compression_size; | |||||
| MS_LOG(INFO) << "Blob data size, on disk: " << disk_size << ", additional after decompression: " << compression_size | |||||
| << ", total: " << total_blob_size_; | |||||
| MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; | MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; | ||||
| @@ -272,6 +280,11 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar | |||||
| return row_group_summary; | return row_group_summary; | ||||
| } | } | ||||
| MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) { | |||||
| *total_blob_size = total_blob_size_; | |||||
| return SUCCESS; | |||||
| } | |||||
| MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, | MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, | ||||
| std::shared_ptr<std::fstream> fs, | std::shared_ptr<std::fstream> fs, | ||||
| std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id, | std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id, | ||||
| @@ -28,11 +28,9 @@ using mindspore::MsLogLevel::INFO; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| ShardWriter::ShardWriter() | ShardWriter::ShardWriter() | ||||
| : shard_count_(1), | |||||
| header_size_(kDefaultHeaderSize), | |||||
| page_size_(kDefaultPageSize), | |||||
| row_count_(0), | |||||
| schema_count_(1) {} | |||||
| : shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) { | |||||
| compression_size_ = 0; | |||||
| } | |||||
| ShardWriter::~ShardWriter() { | ShardWriter::~ShardWriter() { | ||||
| for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) { | for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) { | ||||
| @@ -201,6 +199,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { | |||||
| if (ret == FAILED) { | if (ret == FAILED) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| compression_size_ = shard_header_->GetCompressionSize(); | |||||
| ret = Open(real_addresses, true); | ret = Open(real_addresses, true); | ||||
| if (ret == FAILED) { | if (ret == FAILED) { | ||||
| MS_LOG(ERROR) << "Open file failed"; | MS_LOG(ERROR) << "Open file failed"; | ||||
| @@ -614,7 +613,9 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json> | |||||
| // compress blob | // compress blob | ||||
| if (shard_column_->CheckCompressBlob()) { | if (shard_column_->CheckCompressBlob()) { | ||||
| for (auto &blob : blob_data) { | for (auto &blob : blob_data) { | ||||
| blob = shard_column_->CompressBlob(blob); | |||||
| int64_t compression_bytes = 0; | |||||
| blob = shard_column_->CompressBlob(blob, &compression_bytes); | |||||
| compression_size_ += compression_bytes; | |||||
| } | } | ||||
| } | } | ||||
| @@ -1177,6 +1178,11 @@ MSRStatus ShardWriter::WriteShardHeader() { | |||||
| MS_LOG(ERROR) << "Shard header is null"; | MS_LOG(ERROR) << "Shard header is null"; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
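| // Accumulated compression_size_ can be negative if integer packing expanded the data; clamp to zero. | |||||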
| int64_t compression_temp = compression_size_; | |||||
| uint64_t compression_size = compression_temp > 0 ? compression_temp : 0; | |||||
| shard_header_->SetCompressionSize(compression_size); | |||||
| auto shard_header = shard_header_->SerializeHeader(); | auto shard_header = shard_header_->SerializeHeader(); | ||||
| // Write header data to multi files | // Write header data to multi files | ||||
| if (shard_count_ > static_cast<int>(file_streams_.size()) || shard_count_ > static_cast<int>(shard_header.size())) { | if (shard_count_ > static_cast<int>(file_streams_.size()) || shard_count_ > static_cast<int>(shard_header.size())) { | ||||
| @@ -24,7 +24,15 @@ namespace mindspore { | |||||
| namespace mindrecord { | namespace mindrecord { | ||||
| ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) { | ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) { | ||||
| auto first_schema = shard_header->GetSchemas()[0]; | auto first_schema = shard_header->GetSchemas()[0]; | ||||
| auto schema = first_schema->GetSchema()["schema"]; | |||||
| json schema_json = first_schema->GetSchema(); | |||||
| Init(schema_json, compress_integer); | |||||
| } | |||||
| ShardColumn::ShardColumn(const json &schema_json, bool compress_integer) { Init(schema_json, compress_integer); } | |||||
| void ShardColumn::Init(const json &schema_json, bool compress_integer) { | |||||
| auto schema = schema_json["schema"]; | |||||
| auto blob_fields = schema_json["blob_fields"]; | |||||
| bool has_integer_array = false; | bool has_integer_array = false; | ||||
| for (json::iterator it = schema.begin(); it != schema.end(); ++it) { | for (json::iterator it = schema.begin(); it != schema.end(); ++it) { | ||||
| @@ -52,8 +60,6 @@ ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool | |||||
| column_name_id_[column_name_[i]] = i; | column_name_id_[column_name_[i]] = i; | ||||
| } | } | ||||
| auto blob_fields = first_schema->GetBlobFields(); | |||||
| for (const auto &field : blob_fields) { | for (const auto &field : blob_fields) { | ||||
| blob_column_.push_back(field); | blob_column_.push_back(field); | ||||
| } | } | ||||
| @@ -282,8 +288,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { | |||||
| return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; | return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; | ||||
| } | } | ||||
| std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) { | |||||
| std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) { | |||||
| // Skip if no compress columns | // Skip if no compress columns | ||||
| *compression_size = 0; | |||||
| if (!CheckCompressBlob()) return blob; | if (!CheckCompressBlob()) return blob; | ||||
| std::vector<uint8_t> dst_blob; | std::vector<uint8_t> dst_blob; | ||||
| @@ -295,7 +302,9 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) | |||||
| // Compress and return if blob has 1 column only | // Compress and return if blob has 1 column only | ||||
| if (num_blob_column_ == 1) { | if (num_blob_column_ == 1) { | ||||
| return CompressInt(blob, int_type); | |||||
| dst_blob = CompressInt(blob, int_type); | |||||
| *compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size()); | |||||
| return dst_blob; | |||||
| } | } | ||||
| // Just copy and continue if column data type is not int32/int64 | // Just copy and continue if column data type is not int32/int64 | ||||
| @@ -319,6 +328,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) | |||||
| i_src += kInt64Len + num_bytes; | i_src += kInt64Len + num_bytes; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; | MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; | ||||
| *compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size()); | |||||
| return dst_blob; | return dst_blob; | ||||
| } | } | ||||
| @@ -33,7 +33,9 @@ using mindspore::MsLogLevel::ERROR; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| std::atomic<bool> thread_status(false); | std::atomic<bool> thread_status(false); | ||||
| ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); } | |||||
| ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) { | |||||
| index_ = std::make_shared<Index>(); | |||||
| } | |||||
| MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) { | MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) { | ||||
| shard_count_ = headers.size(); | shard_count_ = headers.size(); | ||||
| @@ -54,6 +56,7 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l | |||||
| ParseShardAddress(header["shard_addresses"]); | ParseShardAddress(header["shard_addresses"]); | ||||
| header_size_ = header["header_size"].get<uint64_t>(); | header_size_ = header["header_size"].get<uint64_t>(); | ||||
| page_size_ = header["page_size"].get<uint64_t>(); | page_size_ = header["page_size"].get<uint64_t>(); | ||||
| compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0; | |||||
| } | } | ||||
| if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { | if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { | ||||
| return FAILED; | return FAILED; | ||||
| @@ -146,9 +149,12 @@ std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &fil | |||||
| return {FAILED, json()}; | return {FAILED, json()}; | ||||
| } | } | ||||
| json raw_header = ret.second; | json raw_header = ret.second; | ||||
| uint64_t compression_size = | |||||
| raw_header.contains("compression_size") ? raw_header["compression_size"].get<uint64_t>() : 0; | |||||
| json header = {{"shard_addresses", raw_header["shard_addresses"]}, | json header = {{"shard_addresses", raw_header["shard_addresses"]}, | ||||
| {"header_size", raw_header["header_size"]}, | {"header_size", raw_header["header_size"]}, | ||||
| {"page_size", raw_header["page_size"]}, | {"page_size", raw_header["page_size"]}, | ||||
| {"compression_size", compression_size}, | |||||
| {"index_fields", raw_header["index_fields"]}, | {"index_fields", raw_header["index_fields"]}, | ||||
| {"blob_fields", raw_header["schema"][0]["blob_fields"]}, | {"blob_fields", raw_header["schema"][0]["blob_fields"]}, | ||||
| {"schema", raw_header["schema"][0]["schema"]}, | {"schema", raw_header["schema"][0]["schema"]}, | ||||
| @@ -343,6 +349,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() { | |||||
| s += "\"index_fields\":" + index + ","; | s += "\"index_fields\":" + index + ","; | ||||
| s += "\"page\":" + pages[shardId] + ","; | s += "\"page\":" + pages[shardId] + ","; | ||||
| s += "\"page_size\":" + std::to_string(page_size_) + ","; | s += "\"page_size\":" + std::to_string(page_size_) + ","; | ||||
| s += "\"compression_size\":" + std::to_string(compression_size_) + ","; | |||||
| s += "\"schema\":" + schema + ","; | s += "\"schema\":" + schema + ","; | ||||
| s += "\"shard_addresses\":" + address + ","; | s += "\"shard_addresses\":" + address + ","; | ||||
| s += "\"shard_id\":" + std::to_string(shardId) + ","; | s += "\"shard_id\":" + std::to_string(shardId) + ","; | ||||
| @@ -3083,20 +3083,22 @@ def _cpp_sampler_fn(sampler, dataset): | |||||
| yield tuple([np.array(x, copy=False) for x in val]) | yield tuple([np.array(x, copy=False) for x in val]) | ||||
| def _cpp_sampler_fn_mp(sampler, dataset, num_worker): | |||||
| def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process): | |||||
| """ | """ | ||||
| Multiprocessing generator function wrapper for mappable dataset with cpp sampler. | Multiprocessing generator function wrapper for mappable dataset with cpp sampler. | ||||
| """ | """ | ||||
| indices = sampler.get_indices() | indices = sampler.get_indices() | ||||
| return _sampler_fn_mp(indices, dataset, num_worker) | |||||
| sample_fn = SamplerFn(dataset, num_worker, multi_process) | |||||
| return sample_fn.process(indices) | |||||
| def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): | |||||
| def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process): | |||||
| """ | """ | ||||
| Multiprocessing generator function wrapper for mappable dataset with python sampler. | Multiprocessing generator function wrapper for mappable dataset with python sampler. | ||||
| """ | """ | ||||
| indices = _fetch_py_sampler_indices(sampler, num_samples) | indices = _fetch_py_sampler_indices(sampler, num_samples) | ||||
| return _sampler_fn_mp(indices, dataset, num_worker) | |||||
| sample_fn = SamplerFn(dataset, num_worker, multi_process) | |||||
| return sample_fn.process(indices) | |||||
| def _fetch_py_sampler_indices(sampler, num_samples): | def _fetch_py_sampler_indices(sampler, num_samples): | ||||
| @@ -3130,63 +3132,92 @@ def _fill_worker_indices(workers, indices, idx): | |||||
| return idx | return idx | ||||
| def _sampler_fn_mp(indices, dataset, num_worker): | |||||
| class SamplerFn: | |||||
| """ | """ | ||||
| Multiprocessing generator function wrapper master process. | |||||
| Master-process wrapper that drives multiprocessing or multithread generator workers. | |||||
| """ | """ | ||||
| workers = [] | |||||
| # Event for end of epoch | |||||
| eoe = multiprocessing.Event() | |||||
| # Create workers | |||||
| for _ in range(num_worker): | |||||
| worker = _GeneratorWorker(dataset, eoe) | |||||
| worker.daemon = True | |||||
| workers.append(worker) | |||||
| # Fill initial index queues | |||||
| idx_cursor = 0 | |||||
| idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) | |||||
| # Start all workers | |||||
| for w in workers: | |||||
| w.start() | |||||
| # Fetch results | |||||
| for i in range(len(indices)): | |||||
| # Fetch result and put index | |||||
| try: | |||||
| result = workers[i % num_worker].get() | |||||
| except queue.Empty: | |||||
| raise Exception("Generator worker process timeout") | |||||
| except KeyboardInterrupt: | |||||
| for w in workers: | |||||
| w.terminate() | |||||
| def __init__(self, dataset, num_worker, multi_process): | |||||
| self.workers = [] | |||||
| self.num_worker = num_worker | |||||
| self.multi_process = multi_process | |||||
| # Event for end of epoch | |||||
| if multi_process is True: | |||||
| self.eoe = multiprocessing.Event() | |||||
| self.eof = multiprocessing.Event() | |||||
| else: | |||||
| self.eoe = threading.Event() | |||||
| self.eof = threading.Event() | |||||
| # Create workers | |||||
| for _ in range(num_worker): | |||||
| if multi_process is True: | |||||
| worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof) | |||||
| else: | |||||
| worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) | |||||
| worker.daemon = True | |||||
| self.workers.append(worker) | |||||
| def process(self, indices): | |||||
| """ | |||||
| The main entry point: start the child processes or threads, fill the index queues, | |||||
| then fetch results from the result queues and yield them. | |||||
| """ | |||||
| # Fill initial index queues | |||||
| idx_cursor = 0 | |||||
| idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) | |||||
| # Start all workers | |||||
| for w in self.workers: | |||||
| w.start() | |||||
| # Fetch results | |||||
| for i in range(len(indices)): | |||||
| # Fetch result and put index | |||||
| try: | |||||
| result = self.workers[i % self.num_worker].get() | |||||
| except queue.Empty: | |||||
| raise Exception("Generator worker process timeout") | |||||
| except KeyboardInterrupt: | |||||
| self.eof.set() | |||||
| for w in self.workers: | |||||
| w.terminate() | |||||
| w.join() | |||||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||||
| if idx_cursor < len(indices): | |||||
| idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) | |||||
| # Set eoe event once all indices are sent | |||||
| if idx_cursor == len(indices) and not self.eoe.is_set(): | |||||
| self.eoe.set() | |||||
| yield tuple([np.array(x, copy=False) for x in result]) | |||||
| def __del__(self): | |||||
| self.eoe.set() | |||||
| self.eof.set() | |||||
| if self.multi_process is False: | |||||
| for w in self.workers: | |||||
| w.join() | w.join() | ||||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||||
| if idx_cursor < len(indices): | |||||
| idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) | |||||
| # Set eoe event once all indices are sent | |||||
| if idx_cursor == len(indices) and not eoe.is_set(): | |||||
| eoe.set() | |||||
| yield tuple([np.array(x, copy=False) for x in result]) | |||||
| def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): | |||||
| def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): | |||||
| """ | """ | ||||
| Multiprocessing generator worker process loop. | |||||
| Worker loop shared by the multiprocessing and multithread generator workers. | |||||
| """ | """ | ||||
| while True: | while True: | ||||
| # Fetch index, block | # Fetch index, block | ||||
| try: | try: | ||||
| idx = idx_queue.get() | |||||
| idx = idx_queue.get(timeout=10) | |||||
| except KeyboardInterrupt: | except KeyboardInterrupt: | ||||
| raise Exception("Generator worker receives KeyboardInterrupt") | raise Exception("Generator worker receives KeyboardInterrupt") | ||||
| except queue.Empty: | |||||
| if eof.is_set() or eoe.is_set(): | |||||
| raise Exception("Generator worker receives queue.Empty") | |||||
| continue | |||||
| if idx is None: | if idx is None: | ||||
| # When the queue is out of scope from master process, a None item can be fetched from the queue. | # When the queue is out of scope from master process, a None item can be fetched from the queue. | ||||
| # Upon receiving None, worker process should check if EOE is set. | # Upon receiving None, worker process should check if EOE is set. | ||||
| assert eoe.is_set(), "" | assert eoe.is_set(), "" | ||||
| return | return | ||||
| if eof.is_set(): | |||||
| return | |||||
| # Fetch data, any exception from __getitem__ will terminate worker and timeout master process | # Fetch data, any exception from __getitem__ will terminate worker and timeout master process | ||||
| result = dataset[idx] | result = dataset[idx] | ||||
| # Send data, block | # Send data, block | ||||
| @@ -3195,17 +3226,42 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): | |||||
| except KeyboardInterrupt: | except KeyboardInterrupt: | ||||
| raise Exception("Generator worker receives KeyboardInterrupt") | raise Exception("Generator worker receives KeyboardInterrupt") | ||||
| del result, idx | del result, idx | ||||
| if eoe.is_set() and idx_queue.empty(): | |||||
| return | |||||
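The worker loop above distinguishes two shutdown signals: eoe ("end of epoch") asks workers to drain any queued indices and then exit, while eof forces an immediate exit, e.g. after a KeyboardInterrupt. A minimal threading-only sketch of the same protocol follows; the names and the squared-index stand-in for dataset[idx] are illustrative, not the repository's code:

import queue
import threading

def worker_loop(idx_queue, res_queue, eoe, eof):
    # Mirrors _generator_worker_loop: eof aborts, eoe drains and then exits.
    while True:
        try:
            idx = idx_queue.get(timeout=10)
        except queue.Empty:
            if eof.is_set() or eoe.is_set():
                return
            continue
        if eof.is_set():
            return                        # hard stop: drop remaining work
        res_queue.put(idx * idx)          # stand-in for result = dataset[idx]
        if eoe.is_set() and idx_queue.empty():
            return                        # graceful stop: every index was served

idx_q, res_q = queue.Queue(16), queue.Queue(16)
eoe, eof = threading.Event(), threading.Event()
worker = threading.Thread(target=worker_loop, args=(idx_q, res_q, eoe, eof), daemon=True)
worker.start()
for i in range(4):
    idx_q.put(i)
eoe.set()                                 # all indices sent; let the worker drain
print([res_q.get(timeout=10) for _ in range(4)])   # [0, 1, 4, 9]
worker.join()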
| class _GeneratorWorker(multiprocessing.Process): | |||||
| class _GeneratorWorkerMt(threading.Thread): | |||||
| """ | |||||
| Worker thread for multithread Generator. | |||||
| """ | |||||
| def __init__(self, dataset, eoe, eof): | |||||
| self.idx_queue = queue.Queue(16) | |||||
| self.res_queue = queue.Queue(16) | |||||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||||
| def put(self, item): | |||||
| """ | |||||
| Put function for worker index queue. Never block. Raise queue.Full on failure. | |||||
| """ | |||||
| self.idx_queue.put_nowait(item) | |||||
| def get(self): | |||||
| """ | |||||
| Get function for worker result queue. Block with timeout. | |||||
| """ | |||||
| return self.res_queue.get(timeout=10) | |||||
| class _GeneratorWorkerMp(multiprocessing.Process): | |||||
| """ | """ | ||||
| Worker process for multiprocess Generator. | Worker process for multiprocess Generator. | ||||
| """ | """ | ||||
| def __init__(self, dataset, eoe): | |||||
| def __init__(self, dataset, eoe, eof): | |||||
| self.idx_queue = multiprocessing.Queue(16) | self.idx_queue = multiprocessing.Queue(16) | ||||
| self.res_queue = multiprocessing.Queue(16) | self.res_queue = multiprocessing.Queue(16) | ||||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe)) | |||||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||||
| def put(self, item): | def put(self, item): | ||||
| """ | """ | ||||
| @@ -3217,7 +3273,7 @@ class _GeneratorWorker(multiprocessing.Process): | |||||
| """ | """ | ||||
| Get function for worker result queue. Block with timeout. | Get function for worker result queue. Block with timeout. | ||||
| """ | """ | ||||
| return self.res_queue.get() | |||||
| return self.res_queue.get(timeout=10) | |||||
| def __del__(self): | def __del__(self): | ||||
| self.terminate() | self.terminate() | ||||
| @@ -3280,6 +3336,8 @@ class GeneratorDataset(MappableDataset): | |||||
| When this argument is specified, 'num_samples' will have no effect. Random accessible input is required. | When this argument is specified, 'num_samples' will have no effect. Random accessible input is required. | ||||
| shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only | shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only | ||||
| when num_shards is also specified. Random accessible input is required. | when num_shards is also specified. Random accessible input is required. | ||||
| python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This | |||||
| option can be beneficial if the Python operation is computationally heavy (default=True). | |||||
| Examples: | Examples: | ||||
| >>> import mindspore.dataset as ds | >>> import mindspore.dataset as ds | ||||
| @@ -3316,12 +3374,14 @@ class GeneratorDataset(MappableDataset): | |||||
| @check_generatordataset | @check_generatordataset | ||||
| def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, | def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, | ||||
| num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||||
| num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, | |||||
| python_multiprocessing=True): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.source = source | self.source = source | ||||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | ||||
| self.num_samples = num_samples | self.num_samples = num_samples | ||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.python_multiprocessing = python_multiprocessing | |||||
| if column_names is not None and not isinstance(column_names, list): | if column_names is not None and not isinstance(column_names, list): | ||||
| column_names = [column_names] | column_names = [column_names] | ||||
| @@ -3403,12 +3463,16 @@ class GeneratorDataset(MappableDataset): | |||||
| sampler_instance.set_num_rows(len(self.source)) | sampler_instance.set_num_rows(len(self.source)) | ||||
| sampler_instance.initialize() | sampler_instance.initialize() | ||||
| if new_op.num_parallel_workers > 1: | if new_op.num_parallel_workers > 1: | ||||
| new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers)) | |||||
| new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, | |||||
| new_op.num_parallel_workers, | |||||
| self.python_multiprocessing)) | |||||
| else: | else: | ||||
| new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) | new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) | ||||
| else: | else: | ||||
| if new_op.num_parallel_workers > 1: | if new_op.num_parallel_workers > 1: | ||||
| new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers)) | |||||
| new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, | |||||
| new_op.num_parallel_workers, | |||||
| self.python_multiprocessing)) | |||||
| else: | else: | ||||
| new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) | new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) | ||||
| else: | else: | ||||
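For illustration, a hedged usage sketch of the new python_multiprocessing flag; the MyDataset class is hypothetical and only shows the random-access interface GeneratorDataset expects:

import numpy as np
import mindspore.dataset as ds

class MyDataset:
    """Random-access source: __getitem__ and __len__ are all that is required."""
    def __init__(self, n=64):
        self.data = np.random.rand(n, 4).astype(np.float32)

    def __getitem__(self, index):
        return (self.data[index],)

    def __len__(self):
        return len(self.data)

# Heavy Python work per sample: spread it over 4 worker processes (the default).
data1 = ds.GeneratorDataset(MyDataset(), ["col0"], num_parallel_workers=4,
                            python_multiprocessing=True)
# Light per-sample work: threads avoid process startup and pickling costs.
data2 = ds.GeneratorDataset(MyDataset(), ["col0"], num_parallel_workers=4,
                            python_multiprocessing=False)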
| @@ -16,8 +16,11 @@ | |||||
| graphdata.py supports loading graph dataset for GNN network training, | graphdata.py supports loading graph dataset for GNN network training, | ||||
| and provides operations related to graph data. | and provides operations related to graph data. | ||||
| """ | """ | ||||
| import atexit | |||||
| import time | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore._c_dataengine import Graph | |||||
| from mindspore._c_dataengine import GraphDataClient | |||||
| from mindspore._c_dataengine import GraphDataServer | |||||
| from mindspore._c_dataengine import Tensor | from mindspore._c_dataengine import Tensor | ||||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | ||||
| @@ -34,14 +37,52 @@ class GraphData: | |||||
| dataset_file (str): One of file names in dataset. | dataset_file (str): One of file names in dataset. | ||||
| num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel | num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel | ||||
| (default=None). | (default=None). | ||||
| working_mode (str, optional): Set the working mode; 'local', 'client' and 'server' are supported (default='local'). | |||||
| - 'local', used in non-distributed training scenarios. | |||||
| - 'client', used in distributed training scenarios, the client does not load data, | |||||
| but obtains data from the server. | |||||
| - 'server', used in distributed training scenarios, the server loads the data | |||||
| and is available to the client. | |||||
| hostname (str, optional): Valid when working_mode is set to 'client' or 'server', | |||||
| set the hostname of the graph data server (default='127.0.0.1'). | |||||
| port (int, optional): Valid when working_mode is set to 'client' or 'server', | |||||
| set the port of the graph data server, the range is 1024-65535 (default=50051). | |||||
| num_client (int, optional): Valid when working_mode is set to 'server', | |||||
| set the number of clients expected to connect, and the server will allocate corresponding | |||||
| resources according to this parameter (default=1). | |||||
| auto_shutdown (bool, optional): Valid when working_mode is set to 'server'. | |||||
| If True, the server exits automatically once all expected clients have connected | |||||
| and then disconnected (default=True). | |||||
| """ | """ | ||||
| @check_gnn_graphdata | @check_gnn_graphdata | ||||
| def __init__(self, dataset_file, num_parallel_workers=None): | |||||
| def __init__(self, dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, | |||||
| num_client=1, auto_shutdown=True): | |||||
| self._dataset_file = dataset_file | self._dataset_file = dataset_file | ||||
| self._working_mode = working_mode | |||||
| if num_parallel_workers is None: | if num_parallel_workers is None: | ||||
| num_parallel_workers = 1 | num_parallel_workers = 1 | ||||
| self._graph = Graph(dataset_file, num_parallel_workers) | |||||
| def stop(): | |||||
| self._graph_data.stop() | |||||
| atexit.register(stop) | |||||
| if working_mode in ['local', 'client']: | |||||
| self._graph_data = GraphDataClient(dataset_file, num_parallel_workers, working_mode, hostname, port) | |||||
| if working_mode == 'server': | |||||
| self._graph_data = GraphDataServer( | |||||
| dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown) | |||||
| try: | |||||
| while not self._graph_data.is_stoped(): | |||||
| time.sleep(1) | |||||
| except KeyboardInterrupt: | |||||
| # self._graph_data.stop() | |||||
| raise Exception("Graph data server receives KeyboardInterrupt") | |||||
| @check_gnn_get_all_nodes | @check_gnn_get_all_nodes | ||||
| def get_all_nodes(self, node_type): | def get_all_nodes(self, node_type): | ||||
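To make the three working modes concrete, a hedged end-to-end sketch of the distributed setup; the dataset file path is a placeholder, and the server constructor blocks until the server shuts down (with auto_shutdown=True, after all num_client clients have connected and then disconnected):

import mindspore.dataset as ds

# Process A - server: loads the graph once and serves it over gRPC.
ds.GraphData("/path/to/graph.mindrecord", working_mode='server',
             hostname='127.0.0.1', port=50051, num_client=2)

# Processes B and C - clients: fetch graph data from the server on demand.
g = ds.GraphData("/path/to/graph.mindrecord", working_mode='client',
                 hostname='127.0.0.1', port=50051)
nodes = g.get_all_nodes(node_type=1)
neighbors = g.get_all_neighbors(nodes[:10].tolist(), neighbor_type=2)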
| @@ -62,7 +103,9 @@ class GraphData: | |||||
| Raises: | Raises: | ||||
| TypeError: If `node_type` is not integer. | TypeError: If `node_type` is not integer. | ||||
| """ | """ | ||||
| return self._graph.get_all_nodes(node_type).as_array() | |||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| return self._graph_data.get_all_nodes(node_type).as_array() | |||||
| @check_gnn_get_all_edges | @check_gnn_get_all_edges | ||||
| def get_all_edges(self, edge_type): | def get_all_edges(self, edge_type): | ||||
| @@ -83,7 +126,9 @@ class GraphData: | |||||
| Raises: | Raises: | ||||
| TypeError: If `edge_type` is not integer. | TypeError: If `edge_type` is not integer. | ||||
| """ | """ | ||||
| return self._graph.get_all_edges(edge_type).as_array() | |||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| return self._graph_data.get_all_edges(edge_type).as_array() | |||||
| @check_gnn_get_nodes_from_edges | @check_gnn_get_nodes_from_edges | ||||
| def get_nodes_from_edges(self, edge_list): | def get_nodes_from_edges(self, edge_list): | ||||
| @@ -99,7 +144,9 @@ class GraphData: | |||||
| Raises: | Raises: | ||||
| TypeError: If `edge_list` is not list or ndarray. | TypeError: If `edge_list` is not list or ndarray. | ||||
| """ | """ | ||||
| return self._graph.get_nodes_from_edges(edge_list).as_array() | |||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| return self._graph_data.get_nodes_from_edges(edge_list).as_array() | |||||
| @check_gnn_get_all_neighbors | @check_gnn_get_all_neighbors | ||||
| def get_all_neighbors(self, node_list, neighbor_type): | def get_all_neighbors(self, node_list, neighbor_type): | ||||
| @@ -123,7 +170,9 @@ class GraphData: | |||||
| TypeError: If `node_list` is not list or ndarray. | TypeError: If `node_list` is not list or ndarray. | ||||
| TypeError: If `neighbor_type` is not integer. | TypeError: If `neighbor_type` is not integer. | ||||
| """ | """ | ||||
| return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() | |||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array() | |||||
| @check_gnn_get_sampled_neighbors | @check_gnn_get_sampled_neighbors | ||||
| def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): | def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): | ||||
| @@ -155,7 +204,9 @@ class GraphData: | |||||
| TypeError: If `neighbor_nums` is not list or ndarray. | TypeError: If `neighbor_nums` is not list or ndarray. | ||||
| TypeError: If `neighbor_types` is not list or ndarray. | TypeError: If `neighbor_types` is not list or ndarray. | ||||
| """ | """ | ||||
| return self._graph.get_sampled_neighbors( | |||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| return self._graph_data.get_sampled_neighbors( | |||||
| node_list, neighbor_nums, neighbor_types).as_array() | node_list, neighbor_nums, neighbor_types).as_array() | ||||
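The distributed test added below samples two hops with fan-outs [2, 2], which gives 1 + 2 + 2*2 = 7 ids per input node (the node itself, its one-hop samples, then their samples) and explains the (2, 7) shape it asserts:

    sampled = g.get_sampled_neighbors(node_list=[101, 102],
                                      neighbor_nums=[2, 2],
                                      neighbor_types=[2, 1])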
| @check_gnn_get_neg_sampled_neighbors | @check_gnn_get_neg_sampled_neighbors | ||||
| @@ -182,7 +233,9 @@ class GraphData: | |||||
| TypeError: If `neg_neighbor_num` is not integer. | TypeError: If `neg_neighbor_num` is not integer. | ||||
| TypeError: If `neg_neighbor_type` is not integer. | TypeError: If `neg_neighbor_type` is not integer. | ||||
| """ | """ | ||||
| return self._graph.get_neg_sampled_neighbors( | |||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| return self._graph_data.get_neg_sampled_neighbors( | |||||
| node_list, neg_neighbor_num, neg_neighbor_type).as_array() | node_list, neg_neighbor_num, neg_neighbor_type).as_array() | ||||
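As the new test exercises it, each output row starts with the input node id followed by neg_neighbor_num sampled negatives (which is why the test slices with neg_nodes[:, 1:]):

    neg = g.get_neg_sampled_neighbors(node_list=[101, 102],
                                      neg_neighbor_num=3,
                                      neg_neighbor_type=1)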
| @check_gnn_get_node_feature | @check_gnn_get_node_feature | ||||
| @@ -207,10 +260,12 @@ class GraphData: | |||||
| TypeError: If `node_list` is not list or ndarray. | TypeError: If `node_list` is not list or ndarray. | ||||
| TypeError: If `feature_types` is not list or ndarray. | TypeError: If `feature_types` is not list or ndarray. | ||||
| """ | """ | ||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| if isinstance(node_list, list): | if isinstance(node_list, list): | ||||
| node_list = np.array(node_list, dtype=np.int32) | node_list = np.array(node_list, dtype=np.int32) | ||||
| return [ | return [ | ||||
| t.as_array() for t in self._graph.get_node_feature( | |||||
| t.as_array() for t in self._graph_data.get_node_feature( | |||||
| Tensor(node_list), | Tensor(node_list), | ||||
| feature_types)] | feature_types)] | ||||
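A usage sketch matching the call in the new distributed test; note that plain Python lists are converted to int32 ndarrays above before crossing into C++:

    feats = g.get_node_feature(node_list=nodes.tolist(), feature_types=[1, 2, 3])
    # one ndarray per requested feature type, aligned with node_list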
| @@ -236,10 +291,12 @@ class GraphData: | |||||
| TypeError: If `edge_list` is not list or ndarray. | TypeError: If `edge_list` is not list or ndarray. | ||||
| TypeError: If `feature_types` is not list or ndarray. | TypeError: If `feature_types` is not list or ndarray. | ||||
| """ | """ | ||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| if isinstance(edge_list, list): | if isinstance(edge_list, list): | ||||
| edge_list = np.array(edge_list, dtype=np.int32) | edge_list = np.array(edge_list, dtype=np.int32) | ||||
| return [ | return [ | ||||
| t.as_array() for t in self._graph.get_edge_feature( | |||||
| t.as_array() for t in self._graph_data.get_edge_feature( | |||||
| Tensor(edge_list), | Tensor(edge_list), | ||||
| feature_types)] | feature_types)] | ||||
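The edge-side counterpart works the same way; feature types 1 and 2 are what the new test requests:

    edge_feats = g.get_edge_feature(edge_list=edges, feature_types=[1, 2])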
| @@ -252,7 +309,9 @@ class GraphData: | |||||
| dict: Meta information of the graph. The keys are node_type, edge_type, node_num, edge_num, | dict: Meta information of the graph. The keys are node_type, edge_type, node_num, edge_num, | ||||
| node_feature_type and edge_feature_type. | node_feature_type and edge_feature_type. | ||||
| """ | """ | ||||
| return self._graph.graph_info() | |||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| return self._graph_data.graph_info() | |||||
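The new test uses this metadata to size its sampler; a sketch:

    info = g.graph_info()
    edge_num = info['edge_num'][0]   # total edge count, 40 in the test data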
| @check_gnn_random_walk | @check_gnn_random_walk | ||||
| def random_walk( | def random_walk( | ||||
| @@ -285,5 +344,7 @@ class GraphData: | |||||
| TypeError: If `target_nodes` is not list or ndarray. | TypeError: If `target_nodes` is not list or ndarray. | ||||
| TypeError: If `meta_path` is not list or ndarray. | TypeError: If `meta_path` is not list or ndarray. | ||||
| """ | """ | ||||
| return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param, | |||||
| default_node).as_array() | |||||
| if self._working_mode == 'server': | |||||
| raise Exception("This method is not supported when working mode is server") | |||||
| return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param, | |||||
| default_node).as_array() | |||||
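A hedged call sketch; the node ids are from the test data, while the meta_path contents and the p/q-style step parameters here are purely illustrative:

    walks = g.random_walk(target_nodes=[101, 102], meta_path=[1, 2, 1],
                          step_home_param=2.0, step_away_param=0.5,
                          default_node=-1)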
| @@ -18,6 +18,7 @@ Built-in validators. | |||||
| """ | """ | ||||
| import inspect as ins | import inspect as ins | ||||
| import os | import os | ||||
| import re | |||||
| from functools import wraps | from functools import wraps | ||||
| import numpy as np | import numpy as np | ||||
| @@ -912,16 +913,36 @@ def check_split(method): | |||||
| return new_method | return new_method | ||||
| def check_hostname(hostname): | |||||
| if not hostname or len(hostname) > 255: | |||||
| return False | |||||
| if hostname[-1] == ".": | |||||
| hostname = hostname[:-1] # strip exactly one dot from the right, if present | |||||
| allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE) | |||||
| return all(allowed.match(x) for x in hostname.split(".")) | |||||
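The validator accepts RFC-1123-style names: dot-separated labels of 1 to 63 letters, digits, or hyphens that neither start nor end with a hyphen, with at most 255 characters overall. For instance:

    assert check_hostname("my-server.example.com")
    assert check_hostname("127.0.0.1")               # digit-only labels are fine
    assert not check_hostname("-bad-.example.com")   # hyphen at a label edge
    assert not check_hostname("a" * 256)             # exceeds the length cap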
| def check_gnn_graphdata(method): | def check_gnn_graphdata(method): | ||||
| """check the input arguments of graphdata.""" | """check the input arguments of graphdata.""" | ||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| [dataset_file, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs) | |||||
| [dataset_file, num_parallel_workers, working_mode, hostname, | |||||
| port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs) | |||||
| check_file(dataset_file) | check_file(dataset_file) | ||||
| if num_parallel_workers is not None: | if num_parallel_workers is not None: | ||||
| check_num_parallel_workers(num_parallel_workers) | check_num_parallel_workers(num_parallel_workers) | ||||
| type_check(hostname, (str,), "hostname") | |||||
| if not check_hostname(hostname): | |||||
| raise ValueError("The hostname is invalid") | |||||
| type_check(working_mode, (str,), "working_mode") | |||||
| if working_mode not in {'local', 'client', 'server'}: | |||||
| raise ValueError("Invalid working mode") | |||||
| type_check(port, (int,), "port") | |||||
| check_value(port, (1024, 65535), "port") | |||||
| type_check(num_client, (int,), "num_client") | |||||
| check_value(num_client, (1, 255), "num_client") | |||||
| type_check(auto_shutdown, (bool,), "auto_shutdown") | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
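With these checks in place, bad connection settings fail fast in Python instead of reaching the C++ layer; the calls below are illustrative only:

    # ds.GraphData(path, 1, 'cluster')           -> ValueError (working mode)
    # ds.GraphData(path, 1, 'client', port=80)   -> ValueError (port range)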
| @@ -15,6 +15,7 @@ | |||||
| """ | """ | ||||
| User-defined API for MindRecord GNN writer. | User-defined API for MindRecord GNN writer. | ||||
| """ | """ | ||||
| import numpy as np | |||||
| social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335], | social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335], | ||||
| [348, 336], [348, 337], [348, 338], [348, 340], [348, 341], | [348, 336], [348, 337], [348, 338], [348, 340], [348, 341], | ||||
| [348, 342], [348, 343], [348, 344], [348, 345], [348, 346], | [348, 342], [348, 343], [348, 344], [348, 345], [348, 346], | ||||
| @@ -29,7 +30,7 @@ social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335], | |||||
| [355, 352], [353, 350], [352, 349], [351, 349], [350, 349]] | [355, 352], [353, 350], [352, 349], [351, 349], [350, 349]] | ||||
| # profile: (num_features, feature_data_types, feature_shapes) | # profile: (num_features, feature_data_types, feature_shapes) | ||||
| node_profile = (0, [], []) | |||||
| node_profile = (2, ["int64", "int32"], [[-1], [-1]]) | |||||
| edge_profile = (0, [], []) | edge_profile = (0, [], []) | ||||
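Reading the updated tuple against the profile comment above, an annotated restatement:

    # profile = (num_features, feature_data_types, feature_shapes)
    node_profile = (2,                   # feature_1 and feature_2, added below
                    ["int64", "int32"],  # dtypes of the np.ones arrays yielded
                    [[-1], [-1]])        # -1 marks a variable-length feature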
| @@ -51,7 +52,9 @@ def yield_nodes(task_id=0): | |||||
| node_list.sort() | node_list.sort() | ||||
| print(node_list) | print(node_list) | ||||
| for node_id in node_list: | for node_id in node_list: | ||||
| node = {'id': node_id, 'type': 1} | |||||
| node = {'id': node_id, 'type': 1, | |||||
| 'feature_1': np.ones((5,), dtype=np.int64), | |||||
| 'feature_2': np.ones((10,), dtype=np.int32)} | |||||
| yield node | yield node | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "minddata/dataset/engine/gnn/node.h" | #include "minddata/dataset/engine/gnn/node.h" | ||||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | #include "minddata/dataset/engine/gnn/graph_loader.h" | ||||
| using namespace mindspore::dataset; | using namespace mindspore::dataset; | ||||
| @@ -39,30 +40,9 @@ class MindDataTestGNNGraph : public UT::Common { | |||||
| MindDataTestGNNGraph() = default; | MindDataTestGNNGraph() = default; | ||||
| }; | }; | ||||
| TEST_F(MindDataTestGNNGraph, TestGraphLoader) { | |||||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||||
| GraphLoader gl(path, 4); | |||||
| EXPECT_TRUE(gl.InitAndLoad().IsOk()); | |||||
| NodeIdMap n_id_map; | |||||
| EdgeIdMap e_id_map; | |||||
| NodeTypeMap n_type_map; | |||||
| EdgeTypeMap e_type_map; | |||||
| NodeFeatureMap n_feature_map; | |||||
| EdgeFeatureMap e_feature_map; | |||||
| DefaultNodeFeatureMap default_node_feature_map; | |||||
| DefaultEdgeFeatureMap default_edge_feature_map; | |||||
| EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map, | |||||
| &default_node_feature_map, &default_edge_feature_map) | |||||
| .IsOk()); | |||||
| EXPECT_EQ(n_id_map.size(), 20); | |||||
| EXPECT_EQ(e_id_map.size(), 40); | |||||
| EXPECT_EQ(n_type_map[2].size(), 10); | |||||
| EXPECT_EQ(n_type_map[1].size(), 10); | |||||
| } | |||||
| TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | ||||
| std::string path = "data/mindrecord/testGraphData/testdata"; | std::string path = "data/mindrecord/testGraphData/testdata"; | ||||
| Graph graph(path, 1); | |||||
| GraphDataImpl graph(path, 1); | |||||
| Status s = graph.Init(); | Status s = graph.Init(); | ||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| @@ -103,7 +83,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||||
| TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | ||||
| std::string path = "data/mindrecord/testGraphData/testdata"; | std::string path = "data/mindrecord/testGraphData/testdata"; | ||||
| Graph graph(path, 1); | |||||
| GraphDataImpl graph(path, 1); | |||||
| Status s = graph.Init(); | Status s = graph.Init(); | ||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| @@ -194,7 +174,7 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | |||||
| TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { | TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { | ||||
| std::string path = "data/mindrecord/testGraphData/testdata"; | std::string path = "data/mindrecord/testGraphData/testdata"; | ||||
| Graph graph(path, 1); | |||||
| GraphDataImpl graph(path, 1); | |||||
| Status s = graph.Init(); | Status s = graph.Init(); | ||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| @@ -237,7 +217,7 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { | |||||
| TEST_F(MindDataTestGNNGraph, TestRandomWalk) { | TEST_F(MindDataTestGNNGraph, TestRandomWalk) { | ||||
| std::string path = "data/mindrecord/testGraphData/sns"; | std::string path = "data/mindrecord/testGraphData/sns"; | ||||
| Graph graph(path, 1); | |||||
| GraphDataImpl graph(path, 1); | |||||
| Status s = graph.Init(); | Status s = graph.Init(); | ||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| @@ -263,7 +243,7 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) { | |||||
| TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) { | TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) { | ||||
| std::string path = "data/mindrecord/testGraphData/sns"; | std::string path = "data/mindrecord/testGraphData/sns"; | ||||
| Graph graph(path, 1); | |||||
| GraphDataImpl graph(path, 1); | |||||
| Status s = graph.Init(); | Status s = graph.Init(); | ||||
| EXPECT_TRUE(s.IsOk()); | EXPECT_TRUE(s.IsOk()); | ||||
| @@ -0,0 +1,125 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| import random | |||||
| import time | |||||
| from multiprocessing import Process | |||||
| import numpy as np | |||||
| import mindspore.dataset as ds | |||||
| from mindspore import log as logger | |||||
| DATASET_FILE = "../data/mindrecord/testGraphData/testdata" | |||||
| def graphdata_startserver(): | |||||
| """ | |||||
| start graphdata server | |||||
| """ | |||||
| logger.info('test start server.\n') | |||||
| ds.GraphData(DATASET_FILE, 1, 'server') | |||||
| class RandomBatchedSampler(ds.Sampler): | |||||
| # RandomBatchedSampler generates a random sequence without replacement in a batched manner | |||||
| def __init__(self, index_range, num_edges_per_sample): | |||||
| super().__init__() | |||||
| self.index_range = index_range | |||||
| self.num_edges_per_sample = num_edges_per_sample | |||||
| def __iter__(self): | |||||
| indices = [i+1 for i in range(self.index_range)] | |||||
| # Reset random seed here if necessary | |||||
| # random.seed(0) | |||||
| random.shuffle(indices) | |||||
| for i in range(0, self.index_range, self.num_edges_per_sample): | |||||
| # Drop remainder | |||||
| if i + self.num_edges_per_sample <= self.index_range: | |||||
| yield indices[i: i + self.num_edges_per_sample] | |||||
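For example (batch contents depend on the shuffle):

    sampler = RandomBatchedSampler(index_range=5, num_edges_per_sample=2)
    print(list(sampler))   # e.g. [[4, 1], [3, 5]]; ids 1..5, remainder dropped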
| class GNNGraphDataset(): | |||||
| def __init__(self, g, batch_num): | |||||
| self.g = g | |||||
| self.batch_num = batch_num | |||||
| def __len__(self): | |||||
| # Total sample size of GNN dataset | |||||
| # In this case, the size should be total_num_edges/num_edges_per_sample | |||||
| return self.g.graph_info()['edge_num'][0] // self.batch_num | |||||
| def __getitem__(self, index): | |||||
| # index will be a list of indices yielded from RandomBatchedSampler | |||||
| # Fetch edges/nodes/samples/features based on indices | |||||
| nodes = self.g.get_nodes_from_edges(index.astype(np.int32)) | |||||
| nodes = nodes[:, 0] | |||||
| neg_nodes = self.g.get_neg_sampled_neighbors( | |||||
| node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1) | |||||
| nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[ | |||||
| 2, 2], neighbor_types=[2, 1]) | |||||
| neg_nodes_neighbors = self.g.get_sampled_neighbors( | |||||
| node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2]) | |||||
| nodes_neighbors_features = self.g.get_node_feature( | |||||
| node_list=nodes_neighbors, feature_types=[2, 3]) | |||||
| neg_neighbors_features = self.g.get_node_feature( | |||||
| node_list=neg_nodes_neighbors, feature_types=[2, 3]) | |||||
| return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1] | |||||
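The shapes asserted further down follow directly from these calls: each sample covers 2 edges and therefore 2 source nodes; fan-outs [2, 2] give 1 + 2 + 2*2 = 7 ids per node, so neighbors is (2, 7); and the 2 x 3 = 6 negative nodes expand the same way to (6, 7).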
| def test_graphdata_distributed(): | |||||
| """ | |||||
| Test distributed | |||||
| """ | |||||
| logger.info('test distributed.\n') | |||||
| p1 = Process(target=graphdata_startserver) | |||||
| p1.start() | |||||
| time.sleep(2) | |||||
| g = ds.GraphData(DATASET_FILE, 1, 'client') | |||||
| nodes = g.get_all_nodes(1) | |||||
| assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] | |||||
| row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3]) | |||||
| assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0], | |||||
| [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1], | |||||
| [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]] | |||||
| assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4] | |||||
| edges = g.get_all_edges(0) | |||||
| assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, | |||||
| 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] | |||||
| features = g.get_edge_feature(edges, [1, 2]) | |||||
| assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, | |||||
| 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0] | |||||
| batch_num = 2 | |||||
| edge_num = g.graph_info()['edge_num'][0] | |||||
| out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"] | |||||
| dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names, | |||||
| sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4, | |||||
| python_multiprocessing=False) | |||||
| dataset = dataset.repeat(2) | |||||
| itr = dataset.create_dict_iterator() | |||||
| i = 0 | |||||
| for data in itr: | |||||
| assert data['neighbors'].shape == (2, 7) | |||||
| assert data['neg_neighbors'].shape == (6, 7) | |||||
| assert data['neighbors_features'].shape == (2, 7) | |||||
| assert data['neg_neighbors_features'].shape == (6, 7) | |||||
| i += 1 | |||||
| assert i == 40 | |||||
| if __name__ == '__main__': | |||||
| test_graphdata_distributed() | |||||