| @@ -15,7 +15,14 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake) | |||
| include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake) | |||
| include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) | |||
| SET(MS_BUILD_GRPC 0) | |||
| if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES) | |||
| SET(MS_BUILD_GRPC 1) | |||
| endif() | |||
| if (ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows") | |||
| SET(MS_BUILD_GRPC 1) | |||
| endif() | |||
| if ("${MS_BUILD_GRPC}") | |||
| # build dependencies of gRPC | |||
| include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake) | |||
| include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake) | |||
| @@ -83,6 +83,7 @@ endif() | |||
| if (ENABLE_TDTQUE) | |||
| add_dependencies(engine-tdt core) | |||
| endif () | |||
| ################### Create _c_dataengine Library ###################### | |||
| set(submodules | |||
| $<TARGET_OBJECTS:core> | |||
| @@ -182,3 +183,7 @@ else() | |||
| set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON) | |||
| endif () | |||
| endif() | |||
| if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows") | |||
| target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++) | |||
| endif() | |||
| @@ -18,83 +18,103 @@ | |||
| #include "pybind11/stl_bind.h" | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/engine/gnn/graph.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data_client.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data_server.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PYBIND_REGISTER( | |||
| Graph, 0, ([](const py::module *m) { | |||
| (void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph") | |||
| .def(py::init([](std::string dataset_file, int32_t num_workers) { | |||
| std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers); | |||
| THROW_IF_ERROR(g_out->Init()); | |||
| return g_out; | |||
| (void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient") | |||
| .def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &working_mode, | |||
| const std::string &hostname, int32_t port) { | |||
| std::shared_ptr<gnn::GraphData> out; | |||
| if (working_mode == "local") { | |||
| out = std::make_shared<gnn::GraphDataImpl>(dataset_file, num_workers); | |||
| } else if (working_mode == "client") { | |||
| out = std::make_shared<gnn::GraphDataClient>(dataset_file, hostname, port); | |||
| } | |||
| THROW_IF_ERROR(out->Init()); | |||
| return out; | |||
| })) | |||
| .def("get_all_nodes", | |||
| [](gnn::Graph &g, gnn::NodeType node_type) { | |||
| [](gnn::GraphData &g, gnn::NodeType node_type) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_all_edges", | |||
| [](gnn::Graph &g, gnn::EdgeType edge_type) { | |||
| [](gnn::GraphData &g, gnn::EdgeType edge_type) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_nodes_from_edges", | |||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) { | |||
| [](gnn::GraphData &g, std::vector<gnn::NodeIdType> edge_list) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_all_neighbors", | |||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) { | |||
| [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_sampled_neighbors", | |||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums, | |||
| [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums, | |||
| std::vector<gnn::NodeType> neighbor_types) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_neg_sampled_neighbors", | |||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num, | |||
| [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num, | |||
| gnn::NodeType neg_neighbor_type) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); | |||
| return out; | |||
| }) | |||
| .def("get_node_feature", | |||
| [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) { | |||
| [](gnn::GraphData &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) { | |||
| TensorRow out; | |||
| THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); | |||
| return out.getRow(); | |||
| }) | |||
| .def("get_edge_feature", | |||
| [](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) { | |||
| [](gnn::GraphData &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) { | |||
| TensorRow out; | |||
| THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); | |||
| return out.getRow(); | |||
| }) | |||
| .def("graph_info", | |||
| [](gnn::Graph &g) { | |||
| [](gnn::GraphData &g) { | |||
| py::dict out; | |||
| THROW_IF_ERROR(g.GraphInfo(&out)); | |||
| return out; | |||
| }) | |||
| .def("random_walk", | |||
| [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path, | |||
| [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path, | |||
| float step_home_param, float step_away_param, gnn::NodeIdType default_node) { | |||
| std::shared_ptr<Tensor> out; | |||
| THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); | |||
| return out; | |||
| }); | |||
| }) | |||
| .def("stop", [](gnn::GraphData &g) { THROW_IF_ERROR(g.Stop()); }); | |||
| (void)py::class_<gnn::GraphDataServer, std::shared_ptr<gnn::GraphDataServer>>(*m, "GraphDataServer") | |||
| .def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port, | |||
| int32_t client_num, bool auto_shutdown) { | |||
| std::shared_ptr<gnn::GraphDataServer> out; | |||
| out = | |||
| std::make_shared<gnn::GraphDataServer>(dataset_file, num_workers, hostname, port, client_num, auto_shutdown); | |||
| THROW_IF_ERROR(out->Init()); | |||
| return out; | |||
| })) | |||
| .def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); }) | |||
| .def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); }); | |||
| })); | |||
| } // namespace dataset | |||
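A note on the dispatch registered above: the init lambda assigns `out` only when working_mode is "local" or "client", so any other string leaves `out` null and `out->Init()` dereferences a null pointer. A minimal standalone sketch of the same dispatch with explicit validation (hypothetical stand-in types, not the classes from this patch):

#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for gnn::GraphData and its two implementations.
struct GraphData {
  virtual ~GraphData() = default;
  virtual void Init() = 0;
};
struct LocalGraphData : GraphData {
  void Init() override {}
};
struct RemoteGraphData : GraphData {
  void Init() override {}
};

std::shared_ptr<GraphData> MakeGraphData(const std::string &working_mode) {
  std::shared_ptr<GraphData> out;
  if (working_mode == "local") {
    out = std::make_shared<LocalGraphData>();
  } else if (working_mode == "client") {
    out = std::make_shared<RemoteGraphData>();
  } else {
    // the binding above would call Init() on a null pointer here
    throw std::invalid_argument("unknown working_mode: " + working_mode);
  }
  out->Init();
  return out;
}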
| @@ -1,9 +1,29 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(engine-gnn OBJECT | |||
| graph.cc | |||
| set(DATASET_ENGINE_GNN_SRC_FILES | |||
| graph_data_impl.cc | |||
| graph_data_client.cc | |||
| graph_data_server.cc | |||
| graph_loader.cc | |||
| graph_feature_parser.cc | |||
| local_node.cc | |||
| local_edge.cc | |||
| feature.cc | |||
| ) | |||
| ) | |||
| if (CMAKE_SYSTEM_NAME MATCHES "Windows") | |||
| add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES}) | |||
| else() | |||
| set(DATASET_ENGINE_GNN_SRC_FILES | |||
| ${DATASET_ENGINE_GNN_SRC_FILES} | |||
| tensor_proto.cc | |||
| grpc_async_server.cc | |||
| graph_data_service_impl.cc | |||
| graph_shared_memory.cc) | |||
| ms_protobuf_generate(TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS "gnn_tensor.proto") | |||
| ms_grpc_generate(GNN_PROTO_SRCS GNN_PROTO_HDRS "gnn_graph_data.proto") | |||
| add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES} ${TENSOR_PROTO_SRCS} ${GNN_PROTO_SRCS}) | |||
| add_dependencies(engine-gnn mindspore::protobuf) | |||
| endif() | |||
| @@ -19,7 +19,8 @@ namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value) : type_name_(type_name), value_(value) {} | |||
| Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory) | |||
| : type_name_(type_name), value_(value), is_shared_memory_(is_shared_memory) {} | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| @@ -31,7 +31,7 @@ class Feature { | |||
| // Constructor | |||
| // @param FeatureType type_name - feature type | |||
| // @param std::shared_ptr<Tensor> value - feature value | |||
| Feature(FeatureType type_name, std::shared_ptr<Tensor> value); | |||
| Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory = false); | |||
| ~Feature() = default; | |||
| @@ -45,6 +45,7 @@ class Feature { | |||
| private: | |||
| FeatureType type_name_; | |||
| std::shared_ptr<Tensor> value_; | |||
| bool is_shared_memory_; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| @@ -0,0 +1,103 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| syntax = "proto3"; | |||
| package mindspore.dataset; | |||
| import "gnn_tensor.proto"; | |||
| message GnnClientRegisterRequestPb { | |||
| int32 pid = 1; | |||
| } | |||
| message GnnFeatureInfoPb { | |||
| int32 type = 1; | |||
| TensorPb feature = 2; | |||
| } | |||
| message GnnClientRegisterResponsePb { | |||
| string error_msg = 1; | |||
| string data_schema = 2; | |||
| int64 shared_memory_key = 3; | |||
| int64 shared_memory_size = 4; | |||
| repeated GnnFeatureInfoPb default_node_feature = 5; | |||
| repeated GnnFeatureInfoPb default_edge_feature = 6; | |||
| } | |||
| message GnnClientUnRegisterRequestPb { | |||
| int32 pid = 1; | |||
| } | |||
| message GnnClientUnRegisterResponsePb { | |||
| string error_msg = 1; | |||
| } | |||
| enum GnnOpName { | |||
| GET_ALL_NODES = 0; | |||
| GET_ALL_EDGES = 1; | |||
| GET_NODES_FROM_EDGES = 2; | |||
| GET_ALL_NEIGHBORS = 3; | |||
| GET_SAMPLED_NEIGHBORS = 4; | |||
| GET_NEG_SAMPLED_NEIGHBORS = 5; | |||
| RANDOM_WALK = 6; | |||
| GET_NODE_FEATURE = 7; | |||
| GET_EDGE_FEATURE = 8; | |||
| } | |||
| message GnnRandomWalkPb { | |||
| float p = 1; | |||
| float q = 2; | |||
| int32 default_id = 3; | |||
| } | |||
| message GnnGraphDataRequestPb { | |||
| GnnOpName op_name = 1; | |||
| repeated int32 id = 2; // node id or edge id | |||
| repeated int32 type = 3; // node type, edge type, neighbor type or feature type | |||
| repeated int32 number = 4; // number of samples | |||
| TensorPb id_tensor = 5; // input ids, node ids or edge ids | |||
| GnnRandomWalkPb random_walk = 6; | |||
| } | |||
| message GnnGraphDataResponsePb { | |||
| string error_msg = 1; | |||
| repeated TensorPb result_data = 2; | |||
| } | |||
| message GnnMetaInfoRequestPb { | |||
| } | |||
| message GnnNodeEdgeInfoPb { | |||
| int32 type = 1; | |||
| int32 num = 2; | |||
| } | |||
| message GnnMetaInfoResponsePb { | |||
| string error_msg = 1; | |||
| repeated GnnNodeEdgeInfoPb node_info = 2; | |||
| repeated GnnNodeEdgeInfoPb edge_info = 3; | |||
| repeated int32 node_feature_type = 4; | |||
| repeated int32 edge_feature_type = 5; | |||
| } | |||
| service GnnGraphData { | |||
| rpc ClientRegister(GnnClientRegisterRequestPb) returns (GnnClientRegisterResponsePb); | |||
| rpc ClientUnRegister(GnnClientUnRegisterRequestPb) returns (GnnClientUnRegisterResponsePb); | |||
| rpc GetGraphData(GnnGraphDataRequestPb) returns (GnnGraphDataResponsePb); | |||
| rpc GetMetaInfo(GnnMetaInfoRequestPb) returns (GnnMetaInfoResponsePb); | |||
| } | |||
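For reference, the simplest way to exercise this service is the synchronous generated stub; the client code later in this patch uses the asynchronous CompletionQueue form instead so it can release the Python GIL while waiting. A minimal sketch, assuming a server listening on 127.0.0.1:50051 (hypothetical address):

#include <iostream>
#include <grpcpp/grpcpp.h>
#include "gnn_graph_data.grpc.pb.h"  // generated from the service above

int main() {
  auto channel = grpc::CreateChannel("127.0.0.1:50051", grpc::InsecureChannelCredentials());
  auto stub = mindspore::dataset::GnnGraphData::NewStub(channel);
  grpc::ClientContext ctx;
  mindspore::dataset::GnnMetaInfoRequestPb request;
  mindspore::dataset::GnnMetaInfoResponsePb response;
  grpc::Status status = stub->GetMetaInfo(&ctx, request, &response);
  if (!status.ok()) {
    std::cerr << "RPC failed: " << status.error_message() << std::endl;
    return 1;
  }
  std::cout << "node types: " << response.node_info_size() << ", edge types: " << response.edge_info_size()
            << std::endl;
  return 0;
}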
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| syntax = "proto3"; | |||
| package mindspore.dataset; | |||
| enum DataTypePb { | |||
| DE_PB_UNKNOWN = 0; | |||
| DE_PB_BOOL = 1; | |||
| DE_PB_INT8 = 2; | |||
| DE_PB_UINT8 = 3; | |||
| DE_PB_INT16 = 4; | |||
| DE_PB_UINT16 = 5; | |||
| DE_PB_INT32 = 6; | |||
| DE_PB_UINT32 = 7; | |||
| DE_PB_INT64 = 8; | |||
| DE_PB_UINT64 = 9; | |||
| DE_PB_FLOAT16 = 10; | |||
| DE_PB_FLOAT32 = 11; | |||
| DE_PB_FLOAT64 = 12; | |||
| DE_PB_STRING = 13; | |||
| } | |||
| message TensorPb { | |||
| repeated int64 dims = 1; // tensor shape info | |||
| DataTypePb tensor_type = 2; // tensor content data type | |||
| bytes data = 3; // tensor data | |||
| } | |||
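The patch converts between this message and dataset tensors in tensor_proto.cc, which is not shown in this hunk. Purely as an illustration, a hand-rolled sketch of packing a 2x3 float32 buffer into a TensorPb:

#include <vector>
#include "gnn_tensor.pb.h"  // generated from the message above

mindspore::dataset::TensorPb MakeFloat32TensorPb(const std::vector<float> &values) {
  mindspore::dataset::TensorPb pb;
  pb.add_dims(2);
  pb.add_dims(3);  // shape [2, 3]; values.size() is expected to be 6
  pb.set_tensor_type(mindspore::dataset::DE_PB_FLOAT32);
  // flat byte buffer; the receiver reinterprets it according to tensor_type
  pb.set_data(values.data(), values.size() * sizeof(float));
  return pb;
}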
| @@ -0,0 +1,134 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/core/tensor_row.h" | |||
| #include "minddata/dataset/engine/gnn/feature.h" | |||
| #include "minddata/dataset/engine/gnn/node.h" | |||
| #include "minddata/dataset/engine/gnn/edge.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| struct MetaInfo { | |||
| std::vector<NodeType> node_type; | |||
| std::vector<EdgeType> edge_type; | |||
| std::map<NodeType, NodeIdType> node_num; | |||
| std::map<EdgeType, EdgeIdType> edge_num; | |||
| std::vector<FeatureType> node_feature_type; | |||
| std::vector<FeatureType> edge_feature_type; | |||
| }; | |||
| class GraphData { | |||
| public: | |||
| // Get all nodes from the graph. | |||
| // @param NodeType node_type - type of node | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code returned | |||
| virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0; | |||
| // Get all edges from the graph. | |||
| // @param EdgeType edge_type - type of edge | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The error code returned | |||
| virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0; | |||
| // Get the node ids from the edges. | |||
| // @param std::vector<EdgeIdType> edge_list - List of edges | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code returned | |||
| virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0; | |||
| // Get all neighbors of the given nodes. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor ids. Because different nodes have different numbers of | |||
| // neighbors, the returned tensor is padded to the maximum neighbor count and missing entries are filled with -1. | |||
| // @return Status - The error code returned | |||
| virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) = 0; | |||
| // Get sampled neighbors. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor ids. | |||
| // @return Status - The error code returned | |||
| virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0; | |||
| // Get negative sampled neighbors. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param NodeIdType samples_num - Number of neighbors sampled | |||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | |||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor ids. | |||
| // @return Status - The error code returned | |||
| virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0; | |||
| // Node2vec random walk. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param std::vector<NodeType> meta_path - node type of each step | |||
| // @param float step_home_param - the return hyperparameter (p) in the node2vec algorithm | |||
| // @param float step_away_param - the in-out hyperparameter (q) in the node2vec algorithm | |||
| // @param NodeIdType default_node - default node id | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids along the walk path | |||
| // @return Status - The error code returned | |||
| virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out) = 0; | |||
| // Get the features of nodes | |||
| // @param std::shared_ptr<Tensor> nodes - List of nodes | |||
| // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code returned | |||
| virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) = 0; | |||
| // Get the features of edges | |||
| // @param std::shared_ptr<Tensor> edges - List of edges | |||
| // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code returned | |||
| virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) = 0; | |||
| // Return meta information to python layer | |||
| virtual Status GraphInfo(py::dict *out) = 0; | |||
| virtual Status Init() = 0; | |||
| virtual Status Stop() = 0; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_ | |||
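The -1 padding contract documented for GetAllNeighbors above is worth making concrete. A small standalone sketch (the assumed output shape, independent of the actual implementation) of padding ragged neighbor lists into a dense buffer:

#include <algorithm>
#include <cstdint>
#include <vector>

using NodeIdType = int32_t;

// Pad ragged neighbor lists to the maximum row length with -1 and flatten,
// mirroring the output contract described for GetAllNeighbors.
std::vector<NodeIdType> PadNeighbors(std::vector<std::vector<NodeIdType>> rows) {
  size_t max_len = 0;
  for (const auto &row : rows) {
    max_len = std::max(max_len, row.size());
  }
  std::vector<NodeIdType> flat;
  flat.reserve(rows.size() * max_len);
  for (auto &row : rows) {
    row.resize(max_len, -1);  // missing neighbors become -1
    flat.insert(flat.end(), row.begin(), row.end());
  }
  return flat;
}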
| @@ -0,0 +1,589 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/gnn/graph_data_client.h" | |||
| #include <unistd.h> | |||
| #include <functional> | |||
| #include <map> | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include "grpcpp/grpcpp.h" | |||
| #endif | |||
| #include "minddata/dataset/core/data_type.h" | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include "minddata/dataset/engine/gnn/tensor_proto.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| GraphDataClient::GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port) | |||
| : dataset_file_(dataset_file), | |||
| host_(hostname), | |||
| port_(port), | |||
| pid_(0), | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| shared_memory_key_(-1), | |||
| shared_memory_size_(0), | |||
| graph_feature_parser_(nullptr), | |||
| graph_shared_memory_(nullptr), | |||
| #endif | |||
| registered_(false) { | |||
| } | |||
| GraphDataClient::~GraphDataClient() { (void)Stop(); } | |||
| Status GraphDataClient::Init() { | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| RETURN_STATUS_UNEXPECTED("Graph data client is not supported in Windows OS"); | |||
| #else | |||
| if (!registered_) { | |||
| std::string server_address = host_ + ":" + std::to_string(port_); | |||
| MS_LOG(INFO) << "Graph data client starting. address:" << server_address; | |||
| pid_ = getpid(); | |||
| grpc::ChannelArguments args; | |||
| args.SetMaxReceiveMessageSize(-1); | |||
| std::shared_ptr<grpc::Channel> channel = | |||
| grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args); | |||
| stub_ = GnnGraphData::NewStub(channel); | |||
| Status status = RegisterToServer(); | |||
| while (status.ToString().find("Initializing") != std::string::npos) { | |||
| MS_LOG(INFO) << "Graph data server is initializing, please wait."; | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(2000)); | |||
| status = RegisterToServer(); | |||
| } | |||
| RETURN_IF_NOT_OK(status); | |||
| MS_LOG(INFO) << "Graph data client successfully registered with server " << server_address; | |||
| } | |||
| RETURN_IF_NOT_OK(InitFeatureParser()); | |||
| return Status::OK(); | |||
| #endif | |||
| } | |||
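Init() above spins on RegisterToServer() while the server is still loading the dataset, sleeping two seconds between attempts; a non-"Initializing" error or a hit 60-second RPC deadline breaks the loop. Extracted into a standalone form, the pattern looks like this (hypothetical Status stand-in; the real loop also logs each retry):

#include <chrono>
#include <functional>
#include <string>
#include <thread>

struct Status {  // hypothetical stand-in for mindspore::dataset::Status
  std::string msg;
};

// Re-run op while it reports that the server is still initializing.
Status RetryWhileInitializing(const std::function<Status()> &op) {
  Status status = op();
  while (status.msg.find("Initializing") != std::string::npos) {
    std::this_thread::sleep_for(std::chrono::milliseconds(2000));
    status = op();
  }
  return status;
}

Note that the loop has no attempt limit; it relies on the per-call deadline and on non-"Initializing" errors to terminate.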
| Status GraphDataClient::Stop() { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| if (registered_) { | |||
| UnRegisterToServer(); | |||
| } | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(GET_ALL_NODES); | |||
| request.add_type(static_cast<google::protobuf::int32>(node_type)); | |||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(GET_ALL_EDGES); | |||
| request.add_type(static_cast<google::protobuf::int32>(edge_type)); | |||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(GET_NODES_FROM_EDGES); | |||
| for (const auto &edge_id : edge_list) { | |||
| request.add_id(static_cast<google::protobuf::int32>(edge_id)); | |||
| } | |||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(GET_ALL_NEIGHBORS); | |||
| for (const auto &node_id : node_list) { | |||
| request.add_id(static_cast<google::protobuf::int32>(node_id)); | |||
| } | |||
| request.add_type(static_cast<google::protobuf::int32>(neighbor_type)); | |||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(GET_SAMPLED_NEIGHBORS); | |||
| for (const auto &node_id : node_list) { | |||
| request.add_id(static_cast<google::protobuf::int32>(node_id)); | |||
| } | |||
| for (const auto &num : neighbor_nums) { | |||
| request.add_number(static_cast<google::protobuf::int32>(num)); | |||
| } | |||
| for (const auto &type : neighbor_types) { | |||
| request.add_type(static_cast<google::protobuf::int32>(type)); | |||
| } | |||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(GET_NEG_SAMPLED_NEIGHBORS); | |||
| for (const auto &node_id : node_list) { | |||
| request.add_id(static_cast<google::protobuf::int32>(node_id)); | |||
| } | |||
| request.add_number(static_cast<google::protobuf::int32>(samples_num)); | |||
| request.add_type(static_cast<google::protobuf::int32>(neg_neighbor_type)); | |||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::RandomWalk(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeType> &meta_path, float step_home_param, | |||
| float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(RANDOM_WALK); | |||
| for (const auto &node_id : node_list) { | |||
| request.add_id(static_cast<google::protobuf::int32>(node_id)); | |||
| } | |||
| for (const auto &type : meta_path) { | |||
| request.add_type(static_cast<google::protobuf::int32>(type)); | |||
| } | |||
| auto walk_param = request.mutable_random_walk(); | |||
| walk_param->set_p(step_home_param); | |||
| walk_param->set_q(step_away_param); | |||
| walk_param->set_default_id(static_cast<google::protobuf::int32>(default_node)); | |||
| RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, | |||
| const std::vector<FeatureType> &feature_types, TensorRow *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| if (!nodes || nodes->Size() == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Input nodes is empty"); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(GET_NODE_FEATURE); | |||
| for (const auto &type : feature_types) { | |||
| request.add_type(static_cast<google::protobuf::int32>(type)); | |||
| } | |||
| RETURN_IF_NOT_OK(TensorToPb(nodes, request.mutable_id_tensor())); | |||
| RETURN_IF_NOT_OK(GetGraphData(request, &response)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == response.result_data().size(), | |||
| "The number of feature types returned by the server is wrong"); | |||
| if (response.result_data().size() > 0) { | |||
| size_t i = 0; | |||
| for (const auto &result : response.result_data()) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(PbToTensor(&result, &tensor)); | |||
| std::shared_ptr<Tensor> fea_tensor; | |||
| RETURN_IF_NOT_OK(ParseNodeFeatureFromMemory(nodes, feature_types[i], tensor, &fea_tensor)); | |||
| out->emplace_back(std::move(fea_tensor)); | |||
| ++i; | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal"); | |||
| } | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, | |||
| const std::vector<FeatureType> &feature_types, TensorRow *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| if (!edges || edges->Size() == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Input edges is empty"); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); | |||
| GnnGraphDataRequestPb request; | |||
| GnnGraphDataResponsePb response; | |||
| request.set_op_name(GET_EDGE_FEATURE); | |||
| for (const auto &type : feature_types) { | |||
| request.add_type(static_cast<google::protobuf::int32>(type)); | |||
| } | |||
| RETURN_IF_NOT_OK(TensorToPb(edges, request.mutable_id_tensor())); | |||
| RETURN_IF_NOT_OK(GetGraphData(request, &response)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == response.result_data().size(), | |||
| "The number of feature types returned by the server is wrong"); | |||
| if (response.result_data().size() > 0) { | |||
| size_t i = 0; | |||
| for (const auto &result : response.result_data()) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(PbToTensor(&result, &tensor)); | |||
| std::shared_ptr<Tensor> fea_tensor; | |||
| RETURN_IF_NOT_OK(ParseEdgeFeatureFromMemory(edges, feature_types[i], tensor, &fea_tensor)); | |||
| out->emplace_back(std::move(fea_tensor)); | |||
| ++i; | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal"); | |||
| } | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GraphInfo(py::dict *out) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| RETURN_IF_NOT_OK(CheckPid()); | |||
| void *tag; | |||
| bool ok; | |||
| grpc::Status status; | |||
| grpc::ClientContext ctx; | |||
| grpc::CompletionQueue cq; | |||
| GnnMetaInfoRequestPb request; | |||
| GnnMetaInfoResponsePb response; | |||
| // One minute timeout | |||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); | |||
| ctx.set_deadline(deadline); | |||
| std::unique_ptr<grpc::ClientAsyncResponseReader<GnnMetaInfoResponsePb>> rpc( | |||
| stub_->PrepareAsyncGetMetaInfo(&ctx, request, &cq)); | |||
| rpc->StartCall(); | |||
| rpc->Finish(&response, &status, &response); | |||
| { | |||
| py::gil_scoped_release gil_release; | |||
| auto success = cq.Next(&tag, &ok); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful"); | |||
| } | |||
| if (status.ok()) { | |||
| if (response.error_msg() != "Success") { | |||
| RETURN_STATUS_UNEXPECTED(response.error_msg()); | |||
| } else { | |||
| MetaInfo meta_info; | |||
| for (const auto &node : response.node_info()) { | |||
| meta_info.node_type.emplace_back(static_cast<NodeType>(node.type())); | |||
| meta_info.node_num[static_cast<NodeType>(node.type())] = static_cast<NodeIdType>(node.num()); | |||
| } | |||
| for (const auto &edge : response.edge_info()) { | |||
| meta_info.edge_type.emplace_back(static_cast<EdgeType>(edge.type())); | |||
| meta_info.edge_num[static_cast<EdgeType>(edge.type())] = static_cast<EdgeIdType>(edge.num()); | |||
| } | |||
| for (const auto &feature_type : response.node_feature_type()) { | |||
| meta_info.node_feature_type.emplace_back(static_cast<FeatureType>(feature_type)); | |||
| } | |||
| for (const auto &feature_type : response.edge_feature_type()) { | |||
| meta_info.edge_feature_type.emplace_back(static_cast<FeatureType>(feature_type)); | |||
| } | |||
| (*out)["node_type"] = py::cast(meta_info.node_type); | |||
| (*out)["edge_type"] = py::cast(meta_info.edge_type); | |||
| (*out)["node_num"] = py::cast(meta_info.node_num); | |||
| (*out)["edge_num"] = py::cast(meta_info.edge_num); | |||
| (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type); | |||
| (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type); | |||
| } | |||
| } else { | |||
| auto error_code = status.error_code(); | |||
| RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code)); | |||
| } | |||
| #endif | |||
| return Status::OK(); | |||
| } | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| Status GraphDataClient::GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response) { | |||
| RETURN_IF_NOT_OK(CheckPid()); | |||
| void *tag; | |||
| bool ok; | |||
| grpc::Status status; | |||
| grpc::ClientContext ctx; | |||
| grpc::CompletionQueue cq; | |||
| // One minute timeout | |||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); | |||
| ctx.set_deadline(deadline); | |||
| std::unique_ptr<grpc::ClientAsyncResponseReader<GnnGraphDataResponsePb>> rpc( | |||
| stub_->PrepareAsyncGetGraphData(&ctx, request, &cq)); | |||
| rpc->StartCall(); | |||
| rpc->Finish(response, &status, response); | |||
| { | |||
| py::gil_scoped_release gil_release; | |||
| auto success = cq.Next(&tag, &ok); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tag == response, "Expect the same tag"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful"); | |||
| } | |||
| if (status.ok()) { | |||
| if (response->error_msg() != "Success") { | |||
| RETURN_STATUS_UNEXPECTED(response->error_msg()); | |||
| } | |||
| } else { | |||
| auto error_code = status.error_code(); | |||
| RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
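The ClientContext/CompletionQueue sequence in GetGraphData is repeated almost verbatim in GraphInfo, RegisterToServer and UnRegisterToServer. As a sketch only, the shared shape could be factored into a template along these lines (hypothetical helper, not part of the patch):

#include <chrono>
#include <memory>
#include <grpcpp/grpcpp.h>

// StartFn is a callable such as
//   [&](grpc::ClientContext *c, grpc::CompletionQueue *q) {
//     return stub->PrepareAsyncGetMetaInfo(c, request, q);
//   };
template <typename ResponsePb, typename StartFn>
grpc::Status CallUnary(StartFn start, ResponsePb *response) {
  grpc::ClientContext ctx;
  grpc::CompletionQueue cq;
  ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(60));
  std::unique_ptr<grpc::ClientAsyncResponseReader<ResponsePb>> rpc = start(&ctx, &cq);
  rpc->StartCall();
  grpc::Status status;
  rpc->Finish(response, &status, response);  // reuse the response pointer as the tag
  void *tag = nullptr;
  bool ok = false;
  if (!cq.Next(&tag, &ok) || tag != response || !ok) {
    return grpc::Status(grpc::StatusCode::INTERNAL, "completion queue failure");
  }
  return status;
}

The async form is used here, rather than the simpler synchronous stub, so the surrounding code can release the Python GIL while blocked on cq.Next().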
| Status GraphDataClient::GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response, | |||
| std::shared_ptr<Tensor> *out) { | |||
| RETURN_IF_NOT_OK(GetGraphData(request, response)); | |||
| if (1 == response->result_data().size()) { | |||
| const TensorPb &result = response->result_data()[0]; | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(PbToTensor(&result, &tensor)); | |||
| *out = std::move(tensor); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type, | |||
| const std::shared_ptr<Tensor> &memory_tensor, | |||
| std::shared_ptr<Tensor> *out) { | |||
| std::shared_ptr<Tensor> default_feature; | |||
| // If no feature can be obtained, fill in the default value | |||
| RETURN_IF_NOT_OK(GetNodeDefaultFeature(feature_type, &default_feature)); | |||
| TensorShape shape(default_feature->shape()); | |||
| auto shape_vec = nodes->shape().AsVector(); | |||
| dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>()); | |||
| shape = shape.PrependDim(size); | |||
| std::shared_ptr<Tensor> fea_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor)); | |||
| dsize_t index = 0; | |||
| auto fea_addr_itr = memory_tensor->begin<int64_t>(); | |||
| for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) { | |||
| int64_t offset = *fea_addr_itr; | |||
| fea_addr_itr++; | |||
| int64_t len = *fea_addr_itr; | |||
| fea_addr_itr++; | |||
| if (*node_itr == kDefaultNodeId || offset < 0 || len <= 0) { | |||
| RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature)); | |||
| } else { | |||
| uchar *start_addr_of_index = nullptr; | |||
| TensorShape remaining({-1}); | |||
| RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining)); | |||
| RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len)); | |||
| } | |||
| index++; | |||
| } | |||
| TensorShape reshape(nodes->shape()); | |||
| for (auto s : default_feature->shape().AsVector()) { | |||
| reshape = reshape.AppendDim(s); | |||
| } | |||
| RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); | |||
| fea_tensor->Squeeze(); | |||
| *out = std::move(fea_tensor); | |||
| return Status::OK(); | |||
| } | |||
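The memory_tensor walked above is effectively an index: for each requested node it carries an int64 (offset, length) pair into the shared-memory block, and offset < 0 or length <= 0 means "fall back to the default feature". A standalone sketch of decoding that layout:

#include <cstdint>
#include <utility>
#include <vector>

// Decode the flat [offset0, len0, offset1, len1, ...] index that the server
// returns for shared-memory features; one pair per requested id.
std::vector<std::pair<int64_t, int64_t>> DecodeFeatureIndex(const std::vector<int64_t> &raw) {
  std::vector<std::pair<int64_t, int64_t>> pairs;
  pairs.reserve(raw.size() / 2);
  for (size_t i = 0; i + 1 < raw.size(); i += 2) {
    pairs.emplace_back(raw[i], raw[i + 1]);  // {offset, length} into shared memory
  }
  return pairs;
}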
| Status GraphDataClient::ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type, | |||
| const std::shared_ptr<Tensor> &memory_tensor, | |||
| std::shared_ptr<Tensor> *out) { | |||
| std::shared_ptr<Tensor> default_feature; | |||
| // If no feature can be obtained, fill in the default value | |||
| RETURN_IF_NOT_OK(GetEdgeDefaultFeature(feature_type, &default_feature)); | |||
| TensorShape shape(default_feature->shape()); | |||
| auto shape_vec = edges->shape().AsVector(); | |||
| dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>()); | |||
| shape = shape.PrependDim(size); | |||
| std::shared_ptr<Tensor> fea_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor)); | |||
| dsize_t index = 0; | |||
| auto fea_addr_itr = memory_tensor->begin<int64_t>(); | |||
| for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) { | |||
| int64_t offset = *fea_addr_itr; | |||
| fea_addr_itr++; | |||
| int64_t len = *fea_addr_itr; | |||
| fea_addr_itr++; | |||
| if (offset < 0 || len <= 0) { | |||
| RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature)); | |||
| } else { | |||
| uchar *start_addr_of_index = nullptr; | |||
| TensorShape remaining({-1}); | |||
| RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining)); | |||
| RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len)); | |||
| } | |||
| index++; | |||
| } | |||
| TensorShape reshape(edges->shape()); | |||
| for (auto s : default_feature->shape().AsVector()) { | |||
| reshape = reshape.AppendDim(s); | |||
| } | |||
| RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); | |||
| fea_tensor->Squeeze(); | |||
| *out = std::move(fea_tensor); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature) { | |||
| auto itr = default_node_feature_map_.find(feature_type); | |||
| if (itr == default_node_feature_map_.end()) { | |||
| std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } else { | |||
| *out_feature = itr->second; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature) { | |||
| auto itr = default_edge_feature_map_.find(feature_type); | |||
| if (itr == default_edge_feature_map_.end()) { | |||
| std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } else { | |||
| *out_feature = itr->second; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::RegisterToServer() { | |||
| RETURN_IF_NOT_OK(CheckPid()); | |||
| void *tag; | |||
| bool ok; | |||
| grpc::Status status; | |||
| grpc::ClientContext ctx; | |||
| grpc::CompletionQueue cq; | |||
| GnnClientRegisterRequestPb request; | |||
| GnnClientRegisterResponsePb response; | |||
| request.set_pid(static_cast<google::protobuf::int32>(pid_)); | |||
| // One minute timeout | |||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); | |||
| ctx.set_deadline(deadline); | |||
| std::unique_ptr<grpc::ClientAsyncResponseReader<GnnClientRegisterResponsePb>> rpc( | |||
| stub_->PrepareAsyncClientRegister(&ctx, request, &cq)); | |||
| rpc->StartCall(); | |||
| rpc->Finish(&response, &status, &response); | |||
| { | |||
| py::gil_scoped_release gil_release; | |||
| auto success = cq.Next(&tag, &ok); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful"); | |||
| } | |||
| if (status.ok()) { | |||
| if (response.error_msg() == "Success") { | |||
| registered_ = true; | |||
| data_schema_ = mindrecord::json::parse(response.data_schema()); | |||
| shared_memory_key_ = static_cast<key_t>(response.shared_memory_key()); | |||
| shared_memory_size_ = response.shared_memory_size(); | |||
| MS_LOG(INFO) << "Register success, recv data_schema:" << response.data_schema(); | |||
| for (const auto &feature_info : response.default_node_feature()) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor)); | |||
| default_node_feature_map_[feature_info.type()] = tensor; | |||
| } | |||
| for (const auto &feature_info : response.default_edge_feature()) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor)); | |||
| default_edge_feature_map_[feature_info.type()] = tensor; | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED(response.error_msg()); | |||
| } | |||
| } else { | |||
| auto error_code = status.error_code(); | |||
| RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::UnRegisterToServer() { | |||
| RETURN_IF_NOT_OK(CheckPid()); | |||
| MS_LOG(INFO) << "Graph data client send unregistered to server "; | |||
| void *tag; | |||
| bool ok; | |||
| grpc::Status status; | |||
| grpc::ClientContext ctx; | |||
| grpc::CompletionQueue cq; | |||
| GnnClientUnRegisterRequestPb request; | |||
| GnnClientUnRegisterResponsePb response; | |||
| request.set_pid(static_cast<google::protobuf::int32>(pid_)); | |||
| // One minute timeout | |||
| auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); | |||
| ctx.set_deadline(deadline); | |||
| std::unique_ptr<grpc::ClientAsyncResponseReader<GnnClientUnRegisterResponsePb>> rpc( | |||
| stub_->PrepareAsyncClientUnRegister(&ctx, request, &cq)); | |||
| rpc->StartCall(); | |||
| rpc->Finish(&response, &status, &response); | |||
| { | |||
| py::gil_scoped_release gil_release; | |||
| auto success = cq.Next(&tag, &ok); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful"); | |||
| } | |||
| if (status.ok()) { | |||
| if (response.error_msg() == "Success") { | |||
| MS_LOG(INFO) << "Unregister success."; | |||
| registered_ = false; | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED(response.error_msg()); | |||
| } | |||
| } else { | |||
| auto error_code = status.error_code(); | |||
| RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataClient::InitFeatureParser() { | |||
| // get shared memory | |||
| graph_shared_memory_ = std::make_unique<GraphSharedMemory>(shared_memory_size_, shared_memory_key_); | |||
| RETURN_IF_NOT_OK(graph_shared_memory_->GetSharedMemory()); | |||
| // build feature parser | |||
| graph_feature_parser_ = std::make_unique<GraphFeatureParser>(ShardColumn(data_schema_)); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
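GraphSharedMemory itself is not part of this hunk; on POSIX it presumably wraps a System V segment identified by the key and size the server hands back at registration. A rough sketch of that attach step, under that assumption only:

#include <cstdint>
#include <cstdio>
#include <sys/ipc.h>
#include <sys/shm.h>

// Attach to an existing System V shared-memory segment (assumption: roughly
// what GraphSharedMemory::GetSharedMemory does with the key and size from
// GnnClientRegisterResponsePb).
void *AttachSharedMemory(key_t key, int64_t size) {
  int shmid = shmget(key, static_cast<size_t>(size), 0666);
  if (shmid < 0) {
    std::perror("shmget");
    return nullptr;
  }
  void *addr = shmat(shmid, nullptr, SHM_RDONLY);  // the client only reads features
  if (addr == reinterpret_cast<void *>(-1)) {
    std::perror("shmat");
    return nullptr;
  }
  return addr;
}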
| @@ -0,0 +1,185 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <map> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| #include <utility> | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include "proto/gnn_graph_data.grpc.pb.h" | |||
| #include "proto/gnn_graph_data.pb.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/gnn/graph_data.h" | |||
| #include "minddata/dataset/engine/gnn/graph_feature_parser.h" | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||
| #endif | |||
| #include "minddata/mindrecord/include/common/shard_utils.h" | |||
| #include "minddata/mindrecord/include/shard_column.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| class GraphDataClient : public GraphData { | |||
| public: | |||
| // Constructor | |||
| // @param std::string dataset_file - path to the graph dataset file | |||
| // @param std::string hostname - hostname of the graph data server | |||
| // @param int32_t port - port of the graph data server | |||
| GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port); | |||
| ~GraphDataClient(); | |||
| Status Init() override; | |||
| Status Stop() override; | |||
| // Get all nodes from the graph. | |||
| // @param NodeType node_type - type of node | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code returned | |||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | |||
| // Get all edges from the graph. | |||
| // @param EdgeType edge_type - type of edge | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The error code returned | |||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | |||
| // Get the node ids from the edges. | |||
| // @param std::vector<EdgeIdType> edge_list - List of edges | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code returned | |||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | |||
| // Get all neighbors of the given nodes. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor ids. Because different nodes have different numbers of | |||
| // neighbors, the returned tensor is padded to the maximum neighbor count and missing entries are filled with -1. | |||
| // @return Status - The error code returned | |||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| // Get sampled neighbors. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop | |||
| // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor ids. | |||
| // @return Status - The error code returned | |||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | |||
| // Get negative sampled neighbors. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param NodeIdType samples_num - Number of neighbors sampled | |||
| // @param NodeType neg_neighbor_type - The type of negative neighbor. | |||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor ids. | |||
| // @return Status - The error code returned | |||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | |||
| // Node2vec random walk. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| // @param std::vector<NodeType> meta_path - node type of each step | |||
| // @param float step_home_param - the return hyperparameter (p) in the node2vec algorithm | |||
| // @param float step_away_param - the in-out hyperparameter (q) in the node2vec algorithm | |||
| // @param NodeIdType default_node - default node id | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids along the walk path | |||
| // @return Status - The error code returned | |||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out) override; | |||
| // Get the features of nodes | |||
| // @param std::shared_ptr<Tensor> nodes - List of nodes | |||
| // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code returned | |||
| Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) override; | |||
| // Get the features of edges | |||
| // @param std::shared_ptr<Tensor> edges - List of edges | |||
| // @param std::vector<FeatureType> feature_types - Types of features; an error will be reported if a feature type | |||
| // does not exist. | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code returned | |||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) override; | |||
| // Return meta information to python layer | |||
| Status GraphInfo(py::dict *out) override; | |||
| private: | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| Status ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type, | |||
| const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out); | |||
| Status ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type, | |||
| const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out); | |||
| Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature); | |||
| Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature); | |||
| Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response); | |||
| Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response, | |||
| std::shared_ptr<Tensor> *out); | |||
| Status RegisterToServer(); | |||
| Status UnRegisterToServer(); | |||
| Status InitFeatureParser(); | |||
| Status CheckPid() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(pid_ == getpid(), | |||
| "Multi-process mode is not supported, please change to use multi-thread"); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| std::string dataset_file_; | |||
| std::string host_; | |||
| int32_t port_; | |||
| int32_t pid_; | |||
| mindrecord::json data_schema_; | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| std::unique_ptr<GnnGraphData::Stub> stub_; | |||
| key_t shared_memory_key_; | |||
| int64_t shared_memory_size_; | |||
| std::unique_ptr<GraphFeatureParser> graph_feature_parser_; | |||
| std::unique_ptr<GraphSharedMemory> graph_shared_memory_; | |||
| std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_node_feature_map_; | |||
| std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_edge_feature_map_; | |||
| #endif | |||
| bool registered_; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_ | |||
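CheckPid() in the header above guards against use after fork(): the gRPC channel and completion queues are not safely usable from a forked child, so pid_ is recorded once in Init() and every call verifies the caller is still that process. A tiny standalone reproduction of the guard (hypothetical, outside the class):

#include <cstdio>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

int main() {
  pid_t init_pid = getpid();  // recorded once, like pid_ in GraphDataClient::Init()
  pid_t child = fork();
  if (child == 0) {
    // In a forked worker, pid_ == getpid() fails, so the client rejects the
    // call instead of reusing the parent's channel.
    std::printf("child %d != init pid %d -> call rejected\n", getpid(), init_pid);
    return 0;
  }
  waitpid(child, nullptr, 0);
  std::printf("parent %d == init pid %d -> call allowed\n", getpid(), init_pid);
  return 0;
}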
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/gnn/graph.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||
| #include <algorithm> | |||
| #include <functional> | |||
| @@ -22,19 +22,25 @@ | |||
| #include <utility> | |||
| #include "minddata/dataset/core/tensor_shape.h" | |||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| Graph::Graph(std::string dataset_file, int32_t num_workers) | |||
| : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { | |||
| GraphDataImpl::GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode) | |||
| : dataset_file_(dataset_file), | |||
| num_workers_(num_workers), | |||
| rnd_(GetRandomDevice()), | |||
| random_walk_(this), | |||
| server_mode_(server_mode) { | |||
| rnd_.seed(GetSeed()); | |||
| MS_LOG(INFO) << "num_workers:" << num_workers; | |||
| } | |||
| Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { | |||
| GraphDataImpl::~GraphDataImpl() {} | |||
| Status GraphDataImpl::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { | |||
| auto itr = node_type_map_.find(node_type); | |||
| if (itr == node_type_map_.end()) { | |||
| std::string err_msg = "Invalid node type:" + std::to_string(node_type); | |||
| @@ -46,8 +52,8 @@ Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { | |||
| } | |||
| template <typename T> | |||
| Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| Status GraphDataImpl::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| if (!type.IsCompatible<T>()) { | |||
| RETURN_STATUS_UNEXPECTED("Data type not compatible"); | |||
| } | |||
| @@ -72,7 +78,7 @@ Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, Data | |||
| } | |||
| template <typename T> | |||
| Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) { | |||
| Status GraphDataImpl::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) { | |||
| if (!data || data->empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Input data is empty"); | |||
| } | |||
| @@ -89,7 +95,7 @@ Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_siz | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { | |||
| Status GraphDataImpl::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { | |||
| auto itr = edge_type_map_.find(edge_type); | |||
| if (itr == edge_type_map_.end()) { | |||
| std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); | |||
| @@ -100,7 +106,7 @@ Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) { | |||
| Status GraphDataImpl::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) { | |||
| if (edge_list.empty()) { | |||
| RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); | |||
| } | |||
| @@ -122,8 +128,8 @@ Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::s | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); | |||
| @@ -143,7 +149,7 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::CheckSamplesNum(NodeIdType samples_num) { | |||
| Status GraphDataImpl::CheckSamplesNum(NodeIdType samples_num) { | |||
| NodeIdType all_nodes_number = | |||
| std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, | |||
| [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); | |||
| @@ -155,7 +161,7 @@ Status Graph::CheckSamplesNum(NodeIdType samples_num) { | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::CheckNeighborType(NodeType neighbor_type) { | |||
| Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) { | |||
| if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { | |||
| std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| @@ -163,9 +169,9 @@ Status Graph::CheckNeighborType(NodeType neighbor_type) { | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||
| Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), | |||
| "The sizes of neighbor_nums and neighbor_types are inconsistent."); | |||
| @@ -205,8 +211,9 @@ Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, | |||
| return Status::OK(); | |||
| } | |||
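
For orientation, a minimal call sketch (the node ids and type values are hypothetical, and the per-row layout — the start node followed by each hop's samples — is inferred from the sampling loop elided in the hunk above, not spelled out in it):

// `graph` stands for any GraphDataImpl instance; ids/types are made up.
std::shared_ptr<Tensor> out;
// Two start nodes; sample 5 type-1 neighbors per node, then 2 type-2
// neighbors for each sampled node, giving rows of 1 + 5 + 5*2 = 16 ids.
Status s = graph.GetSampledNeighbors(/*node_list=*/{101, 102},
                                     /*neighbor_nums=*/{5, 2},
                                     /*neighbor_types=*/{1, 2}, &out);
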
| Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::unordered_set<NodeIdType> &exclude_data, | |||
| int32_t samples_num, std::vector<NodeIdType> *out_samples) { | |||
| Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data, | |||
| const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num, | |||
| std::vector<NodeIdType> *out_samples) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); | |||
| std::vector<NodeIdType> shuffled_id(data.size()); | |||
| std::iota(shuffled_id.begin(), shuffled_id.end(), 0); | |||
| @@ -223,8 +230,8 @@ Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::uno | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||
| Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); | |||
| RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); | |||
| @@ -260,9 +267,9 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out) { | |||
| Status GraphDataImpl::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out) { | |||
| RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); | |||
| std::vector<std::vector<NodeIdType>> walks; | |||
| RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); | |||
| @@ -270,7 +277,7 @@ Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::ve | |||
| return Status::OK(); | |||
| } | |||
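
Assuming the usual node2vec formulation (the member names step_home_param_ and step_away_param_ make the mapping explicit), step_home_param and step_away_param are node2vec's return parameter p and in-out parameter q. The unnormalized bias applied when stepping from previous node t to candidate x at distance d_tx is:

\alpha_{p,q}(t, x) =
  \begin{cases}
    1/p & \text{if } d_{tx} = 0 \\
    1   & \text{if } d_{tx} = 1 \\
    1/q & \text{if } d_{tx} = 2
  \end{cases}

GetEdgeProbability further below computes these unnormalized weights before handing them to GenerateProbability.
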
| Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||
| Status GraphDataImpl::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||
| auto itr = default_node_feature_map_.find(feature_type); | |||
| if (itr == default_node_feature_map_.end()) { | |||
| std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | |||
| @@ -281,7 +288,7 @@ Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Fe | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||
| Status GraphDataImpl::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { | |||
| auto itr = default_edge_feature_map_.find(feature_type); | |||
| if (itr == default_edge_feature_map_.end()) { | |||
| std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); | |||
| @@ -292,8 +299,8 @@ Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Fe | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) { | |||
| Status GraphDataImpl::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, | |||
| const std::vector<FeatureType> &feature_types, TensorRow *out) { | |||
| if (!nodes || nodes->Size() == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Input nodes is empty"); | |||
| } | |||
| @@ -339,8 +346,49 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) { | |||
| Status GraphDataImpl::GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| if (!nodes || nodes->Size() == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Input nodes is empty"); | |||
| } | |||
| TensorShape shape = nodes->shape().AppendDim(2); | |||
| std::shared_ptr<Tensor> fea_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor)); | |||
| auto out_fea_itr = fea_tensor->begin<int64_t>(); | |||
| for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) { | |||
| if (*node_itr == kDefaultNodeId) { | |||
| *out_fea_itr = -1; | |||
| ++out_fea_itr; | |||
| *out_fea_itr = -1; | |||
| ++out_fea_itr; | |||
| } else { | |||
| std::shared_ptr<Node> node; | |||
| RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); | |||
| std::shared_ptr<Feature> feature; | |||
| if (!node->GetFeatures(type, &feature).IsOk()) { | |||
| *out_fea_itr = -1; | |||
| ++out_fea_itr; | |||
| *out_fea_itr = -1; | |||
| ++out_fea_itr; | |||
| } else { | |||
| for (auto fea_itr = feature->Value()->begin<int64_t>(); fea_itr != feature->Value()->end<int64_t>(); | |||
| ++fea_itr) { | |||
| *out_fea_itr = *fea_itr; | |||
| ++out_fea_itr; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| fea_tensor->Squeeze(); | |||
| *out = std::move(fea_tensor); | |||
| return Status::OK(); | |||
| } | |||
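
In server mode, Feature::Value() evidently holds an int64 pair per feature rather than raw values: the AppendDim(2) shape and the (-1, -1) sentinel written above suggest each pair is a (shared-memory offset, size) handle that the client resolves against the shared segment. A hedged decoding sketch follows; the GetData-style accessor on GraphSharedMemory is an assumption, since its real signature lives in graph_shared_memory.h, which this diff does not include:

// Walk the (offset, size) pairs produced by GetNodeFeatureSharedMemory.
Status DecodeSharedMemoryFeatures(const std::shared_ptr<Tensor> &fea_tensor,
                                  GraphSharedMemory *shm) {  // accessor owner, assumed
  for (auto it = fea_tensor->begin<int64_t>(); it != fea_tensor->end<int64_t>();) {
    int64_t offset = *it;
    ++it;
    int64_t size = *it;
    ++it;
    if (offset == -1 && size == -1) continue;  // sentinel: node lacks this feature
    std::vector<uint8_t> buf(size);
    // shm->GetData(buf.data(), size, offset, size);  // hypothetical signature
  }
  return Status::OK();
}
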
| Status GraphDataImpl::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, | |||
| const std::vector<FeatureType> &feature_types, TensorRow *out) { | |||
| if (!edges || edges->Size() == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Input edges is empty"); | |||
| } | |||
| @@ -382,12 +430,45 @@ Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::ve | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::Init() { | |||
| Status GraphDataImpl::GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type, | |||
| std::shared_ptr<Tensor> *out) { | |||
| if (!edges || edges->Size() == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Input edges is empty"); | |||
| } | |||
| TensorShape shape = edges->shape().AppendDim(2); | |||
| std::shared_ptr<Tensor> fea_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor)); | |||
| auto out_fea_itr = fea_tensor->begin<int64_t>(); | |||
| for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) { | |||
| std::shared_ptr<Edge> edge; | |||
| RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge)); | |||
| std::shared_ptr<Feature> feature; | |||
| if (!edge->GetFeatures(type, &feature).IsOk()) { | |||
| *out_fea_itr = -1; | |||
| ++out_fea_itr; | |||
| *out_fea_itr = -1; | |||
| ++out_fea_itr; | |||
| } else { | |||
| for (auto fea_itr = feature->Value()->begin<int64_t>(); fea_itr != feature->Value()->end<int64_t>(); ++fea_itr) { | |||
| *out_fea_itr = *fea_itr; | |||
| ++out_fea_itr; | |||
| } | |||
| } | |||
| } | |||
| fea_tensor->Squeeze(); | |||
| *out = std::move(fea_tensor); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataImpl::Init() { | |||
| RETURN_IF_NOT_OK(LoadNodeAndEdge()); | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetMetaInfo(MetaInfo *meta_info) { | |||
| Status GraphDataImpl::GetMetaInfo(MetaInfo *meta_info) { | |||
| meta_info->node_type.resize(node_type_map_.size()); | |||
| std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), | |||
| [](auto itr) { return itr.first; }); | |||
| @@ -427,7 +508,7 @@ Status Graph::GetMetaInfo(MetaInfo *meta_info) { | |||
| } | |||
| #ifdef ENABLE_PYTHON | |||
| Status Graph::GraphInfo(py::dict *out) { | |||
| Status GraphDataImpl::GraphInfo(py::dict *out) { | |||
| MetaInfo meta_info; | |||
| RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); | |||
| (*out)["node_type"] = py::cast(meta_info.node_type); | |||
| @@ -440,18 +521,16 @@ Status Graph::GraphInfo(py::dict *out) { | |||
| } | |||
| #endif | |||
| Status Graph::LoadNodeAndEdge() { | |||
| GraphLoader gl(dataset_file_, num_workers_); | |||
| Status GraphDataImpl::LoadNodeAndEdge() { | |||
| GraphLoader gl(this, dataset_file_, num_workers_, server_mode_); | |||
| // ask graph_loader to load everything into memory | |||
| RETURN_IF_NOT_OK(gl.InitAndLoad()); | |||
| // get all maps | |||
| RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, | |||
| &node_feature_map_, &edge_feature_map_, &default_node_feature_map_, | |||
| &default_edge_feature_map_)); | |||
| RETURN_IF_NOT_OK(gl.GetNodesAndEdges()); | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||
| Status GraphDataImpl::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||
| auto itr = node_id_map_.find(id); | |||
| if (itr == node_id_map_.end()) { | |||
| std::string err_msg = "Invalid node id:" + std::to_string(id); | |||
| @@ -462,7 +541,7 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) { | |||
| Status GraphDataImpl::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) { | |||
| auto itr = edge_id_map_.find(id); | |||
| if (itr == edge_id_map_.end()) { | |||
| std::string err_msg = "Invalid edge id:" + std::to_string(id); | |||
| @@ -473,12 +552,13 @@ Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) { | |||
| return Status::OK(); | |||
| } | |||
| Graph::RandomWalkBase::RandomWalkBase(Graph *graph) | |||
| GraphDataImpl::RandomWalkBase::RandomWalkBase(GraphDataImpl *graph) | |||
| : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} | |||
| Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, const NodeIdType default_node, | |||
| int32_t num_walks, int32_t num_workers) { | |||
| Status GraphDataImpl::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, | |||
| const std::vector<NodeType> &meta_path, float step_home_param, | |||
| float step_away_param, const NodeIdType default_node, int32_t num_walks, | |||
| int32_t num_workers) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); | |||
| node_list_ = node_list; | |||
| if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { | |||
| @@ -516,7 +596,7 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) { | |||
| Status GraphDataImpl::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) { | |||
| // Simulate a random walk starting from start node. | |||
| auto walk = std::vector<NodeIdType>(1, start_node);  // walk is a vector seeded with the start node | |||
| // walk simulate | |||
| @@ -556,8 +636,8 @@ Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::ve | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) { | |||
| for (int32_t i = 0; i < num_walks_; i++) { | |||
| Status GraphDataImpl::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) { | |||
| for (int32_t i = 0; i < num_walks_; ++i) { | |||
| for (const auto &node : node_list_) { | |||
| std::vector<NodeIdType> walk; | |||
| RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); | |||
| @@ -567,8 +647,8 @@ Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, | |||
| std::shared_ptr<StochasticIndex> *node_probability) { | |||
| Status GraphDataImpl::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, | |||
| std::shared_ptr<StochasticIndex> *node_probability) { | |||
| // Generate alias nodes | |||
| std::shared_ptr<Node> node; | |||
| graph_->GetNodeByNodeId(node_id, &node); | |||
| @@ -581,8 +661,9 @@ Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, cons | |||
| return Status::OK(); | |||
| } | |||
| Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, | |||
| std::shared_ptr<StochasticIndex> *edge_probability) { | |||
| Status GraphDataImpl::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, | |||
| uint32_t meta_path_index, | |||
| std::shared_ptr<StochasticIndex> *edge_probability) { | |||
| // Get the alias edge setup lists for a given edge. | |||
| std::shared_ptr<Node> src_node; | |||
| graph_->GetNodeByNodeId(src, &src_node); | |||
| @@ -616,7 +697,7 @@ Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const No | |||
| return Status::OK(); | |||
| } | |||
| StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) { | |||
| StochasticIndex GraphDataImpl::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) { | |||
| uint32_t K = probability.size(); | |||
| std::vector<int32_t> switch_to_large_index(K, 0); | |||
| std::vector<float> weight(K, .0); | |||
| @@ -644,7 +725,7 @@ StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<flo | |||
| return StochasticIndex(switch_to_large_index, weight); | |||
| } | |||
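
GenerateProbability builds a Walker alias table: StochasticIndex.second holds per-slot acceptance weights and .first the alias ("switch to large") slots, so a sample is drawn in O(1). A self-contained sketch of the draw under that reading — it mirrors WalkToNextNode below, whose body is partially elided in this hunk:

#include <cstdint>
#include <random>
#include <utility>
#include <vector>

using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;

uint32_t AliasDraw(const StochasticIndex &index, std::mt19937 *rnd) {
  const auto &alias = index.first;    // switch_to_large_index
  const auto &weight = index.second;  // per-slot acceptance probability
  std::uniform_int_distribution<uint32_t> pick_slot(0, alias.size() - 1);
  std::uniform_real_distribution<float> coin(0.0f, 1.0f);
  uint32_t slot = pick_slot(*rnd);
  // Accept the slot with probability weight[slot]; otherwise jump to its alias.
  return coin(*rnd) < weight[slot] ? slot : static_cast<uint32_t>(alias[slot]);
}
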
| uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { | |||
| uint32_t GraphDataImpl::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { | |||
| auto switch_to_large_index = stochastic_index.first; | |||
| auto weight = stochastic_index.second; | |||
| const uint32_t size_of_index = switch_to_large_index.size(); | |||
| @@ -662,7 +743,7 @@ uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic | |||
| } | |||
| template <typename T> | |||
| std::vector<float> Graph::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) { | |||
| std::vector<float> GraphDataImpl::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) { | |||
| // accumulate with a float init value so fractional weights are not truncated to int | |||
| float sum_probability = | |||
| std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0.0f); | |||
| if (sum_probability < kGnnEpsilon) { | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| @@ -25,13 +25,11 @@ | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/core/tensor_row.h" | |||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | |||
| #include "minddata/dataset/engine/gnn/feature.h" | |||
| #include "minddata/dataset/engine/gnn/node.h" | |||
| #include "minddata/dataset/engine/gnn/edge.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data.h" | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||
| #endif | |||
| #include "minddata/mindrecord/include/common/shard_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -41,41 +39,32 @@ const float kGnnEpsilon = 0.0001; | |||
| const uint32_t kMaxNumWalks = 80; | |||
| using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>; | |||
| struct MetaInfo { | |||
| std::vector<NodeType> node_type; | |||
| std::vector<EdgeType> edge_type; | |||
| std::map<NodeType, NodeIdType> node_num; | |||
| std::map<EdgeType, EdgeIdType> edge_num; | |||
| std::vector<FeatureType> node_feature_type; | |||
| std::vector<FeatureType> edge_feature_type; | |||
| }; | |||
| class Graph { | |||
| class GraphDataImpl : public GraphData { | |||
| public: | |||
| // Constructor | |||
| // @param std::string dataset_file - path of the dataset file to load | |||
| // @param int32_t num_workers - number of parallel threads | |||
| // @param bool server_mode - true when this graph backs a GraphDataServer instance | |||
| Graph(std::string dataset_file, int32_t num_workers); | |||
| GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode = false); | |||
| ~Graph() = default; | |||
| ~GraphDataImpl(); | |||
| // Get all nodes from the graph. | |||
| // @param NodeType node_type - type of node | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code return | |||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out); | |||
| Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override; | |||
| // Get all edges from the graph. | |||
| // @param EdgeType edge_type - type of edge | |||
| // @param std::shared_ptr<Tensor> *out - Returned edge ids | |||
| // @return Status - The error code return | |||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out); | |||
| Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override; | |||
| // Get the node id from the edge. | |||
| // @param std::vector<EdgeIdType> edge_list - List of edges | |||
| // @param std::shared_ptr<Tensor> *out - Returned node ids | |||
| // @return Status - The error code return | |||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out); | |||
| Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override; | |||
| // Get all neighbors of the given nodes. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| @@ -85,7 +74,7 @@ class Graph { | |||
| // is not enough, the tensor is padded with -1. | |||
| // @return Status - The error code return | |||
| Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, | |||
| std::shared_ptr<Tensor> *out); | |||
| std::shared_ptr<Tensor> *out) override; | |||
| // Get sampled neighbors. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| @@ -94,7 +83,7 @@ class Graph { | |||
| // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. | |||
| // @return Status - The error code return | |||
| Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out); | |||
| const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; | |||
| // Get negative sampled neighbors. | |||
| // @param std::vector<NodeType> node_list - List of nodes | |||
| @@ -103,7 +92,7 @@ class Graph { | |||
| // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. | |||
| // @return Status - The error code return | |||
| Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out); | |||
| NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override; | |||
| // Node2vec random walk. | |||
| // @param std::vector<NodeIdType> node_list - List of nodes | |||
| @@ -115,7 +104,7 @@ class Graph { | |||
| // @return Status - The error code return | |||
| Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param, float step_away_param, NodeIdType default_node, | |||
| std::shared_ptr<Tensor> *out); | |||
| std::shared_ptr<Tensor> *out) override; | |||
| // Get the feature of a node | |||
| // @param std::shared_ptr<Tensor> nodes - List of nodes | |||
| @@ -124,16 +113,22 @@ class Graph { | |||
| // @param TensorRow *out - Returned features | |||
| // @return Status - The error code return | |||
| Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out); | |||
| TensorRow *out) override; | |||
| Status GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type, | |||
| std::shared_ptr<Tensor> *out); | |||
| // Get the feature of an edge | |||
| // @param std::shared_ptr<Tensor> edget - List of edges | |||
| // @param std::shared_ptr<Tensor> edges - List of edges | |||
| // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type | |||
| // does not exist. | |||
| // @param Tensor *out - Returned features | |||
| // @return Status - The error code return | |||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edget, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out); | |||
| Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, | |||
| TensorRow *out) override; | |||
| Status GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type, | |||
| std::shared_ptr<Tensor> *out); | |||
| // Get meta information of graph | |||
| // @param MetaInfo *meta_info - Returned meta information | |||
| @@ -142,15 +137,34 @@ class Graph { | |||
| #ifdef ENABLE_PYTHON | |||
| // Return meta information to python layer | |||
| Status GraphInfo(py::dict *out); | |||
| Status GraphInfo(py::dict *out) override; | |||
| #endif | |||
| Status Init(); | |||
| const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultNodeFeatures() { | |||
| return &default_node_feature_map_; | |||
| } | |||
| const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultEdgeFeatures() { | |||
| return &default_edge_feature_map_; | |||
| } | |||
| Status Init() override; | |||
| Status Stop() override { return Status::OK(); } | |||
| std::string GetDataSchema() { return data_schema_.dump(); } | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| key_t GetSharedMemoryKey() { return graph_shared_memory_->memory_key(); } | |||
| int64_t GetSharedMemorySize() { return graph_shared_memory_->memory_size(); } | |||
| #endif | |||
| private: | |||
| friend class GraphLoader; | |||
| class RandomWalkBase { | |||
| public: | |||
| explicit RandomWalkBase(Graph *graph); | |||
| explicit RandomWalkBase(GraphDataImpl *graph); | |||
| Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, | |||
| float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, | |||
| @@ -176,7 +190,7 @@ class Graph { | |||
| template <typename T> | |||
| std::vector<float> Normalize(const std::vector<T> &non_normalized_probability); | |||
| Graph *graph_; | |||
| GraphDataImpl *graph_; | |||
| std::vector<NodeIdType> node_list_; | |||
| std::vector<NodeType> meta_path_; | |||
| float step_home_param_; // Return hyper parameter. Default is 1.0 | |||
| @@ -248,7 +262,11 @@ class Graph { | |||
| int32_t num_workers_; // The number of worker threads | |||
| std::mt19937 rnd_; | |||
| RandomWalkBase random_walk_; | |||
| mindrecord::json data_schema_; | |||
| bool server_mode_; | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| std::unique_ptr<GraphSharedMemory> graph_shared_memory_; | |||
| #endif | |||
| std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_; | |||
| std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_; | |||
| @@ -264,4 +282,4 @@ class Graph { | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_ | |||
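
Taken together, local mode reduces to constructing the impl directly. A minimal sketch using only the API declared in this header (the dataset path is a placeholder):

gnn::GraphDataImpl graph("/path/to/graph_data.mindrecord",
                         /*num_workers=*/4);  // server_mode defaults to false
Status s = graph.Init();  // loads nodes and edges through GraphLoader
std::shared_ptr<Tensor> nodes;
if (s.IsOk()) {
  s = graph.GetAllNodes(/*node_type=*/1, &nodes);  // node_type value is hypothetical
}
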
| @@ -0,0 +1,133 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/gnn/graph_data_server.h" | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <iterator> | |||
| #include <numeric> | |||
| #include <utility> | |||
| #include "minddata/dataset/core/tensor_shape.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, | |||
| int32_t port, int32_t client_num, bool auto_shutdown) | |||
| : dataset_file_(dataset_file), | |||
| num_workers_(num_workers), | |||
| client_num_(client_num), | |||
| max_connected_client_num_(0), | |||
| auto_shutdown_(auto_shutdown), | |||
| state_(kGdsUninit) { | |||
| tg_ = std::make_unique<TaskGroup>(); | |||
| graph_data_impl_ = std::make_unique<GraphDataImpl>(dataset_file, num_workers, true); | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| service_impl_ = std::make_unique<GraphDataServiceImpl>(this, graph_data_impl_.get()); | |||
| async_server_ = std::make_unique<GraphDataGrpcServer>(hostname, port, service_impl_.get()); | |||
| #endif | |||
| } | |||
| Status GraphDataServer::Init() { | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS"); | |||
| #else | |||
| set_state(kGdsInitializing); | |||
| RETURN_IF_NOT_OK(async_server_->Run()); | |||
| // Load the graph asynchronously so the rpc service can come up first and answer | |||
| // "Initializing" to clients that register before loading completes. | |||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask("init graph data impl", std::bind(&GraphDataServer::InitGraphDataImpl, this))); | |||
| for (int32_t i = 0; i < num_workers_; ++i) { | |||
| RETURN_IF_NOT_OK( | |||
| tg_->CreateAsyncTask("start async rpc service", std::bind(&GraphDataServer::StartAsyncRpcService, this))); | |||
| } | |||
| if (auto_shutdown_) { | |||
| RETURN_IF_NOT_OK( | |||
| tg_->CreateAsyncTask("judge auto shutdown server", std::bind(&GraphDataServer::JudgeAutoShutdownServer, this))); | |||
| } | |||
| return Status::OK(); | |||
| #endif | |||
| } | |||
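
A minimal lifecycle sketch using only the API declared in this PR (the dataset path and endpoint are placeholders; the fragment assumes <chrono> and <thread> are available):

GraphDataServer server("/path/to/graph_data.mindrecord", /*num_workers=*/4,
                       "127.0.0.1", 50051, /*client_num=*/1, /*auto_shutdown=*/true);
Status s = server.Init();  // non-blocking: loading and the rpc service run as TaskGroup tasks
if (s.IsOk()) {
  while (!server.IsStoped()) {  // with auto_shutdown, exits once all clients unregister
    std::this_thread::sleep_for(std::chrono::seconds(1));
  }
}
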
| Status GraphDataServer::InitGraphDataImpl() { | |||
| TaskManager::FindMe()->Post(); | |||
| Status s = graph_data_impl_->Init(); | |||
| if (s.IsOk()) { | |||
| set_state(kGdsRunning); | |||
| } else { | |||
| (void)Stop(); | |||
| } | |||
| return s; | |||
| } | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| Status GraphDataServer::StartAsyncRpcService() { | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(async_server_->HandleRequest()); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
| Status GraphDataServer::JudgeAutoShutdownServer() { | |||
| TaskManager::FindMe()->Post(); | |||
| while (true) { | |||
| if (auto_shutdown_ && (max_connected_client_num_ >= client_num_) && client_pid_.empty()) { | |||
| MS_LOG(INFO) << "All clients have unregistered, shutting down the server automatically."; | |||
| RETURN_IF_NOT_OK(Stop()); | |||
| break; | |||
| } | |||
| if (state_ == kGdsStopped) { | |||
| break; | |||
| } | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(1000)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServer::Stop() { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| async_server_->Stop(); | |||
| #endif | |||
| set_state(kGdsStopped); | |||
| graph_data_impl_.reset(); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServer::ClientRegister(int32_t pid) { | |||
| std::unique_lock<std::mutex> lck(mutex_); | |||
| MS_LOG(INFO) << "client register pid:" << std::to_string(pid); | |||
| client_pid_.emplace(pid); | |||
| if (client_pid_.size() > max_connected_client_num_) { | |||
| max_connected_client_num_ = client_pid_.size(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServer::ClientUnRegister(int32_t pid) { | |||
| std::unique_lock<std::mutex> lck(mutex_); | |||
| auto itr = client_pid_.find(pid); | |||
| if (itr != client_pid_.end()) { | |||
| client_pid_.erase(itr); | |||
| MS_LOG(INFO) << "client unregister pid:" << std::to_string(pid); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,196 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_ | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <unordered_set> | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include "grpcpp/grpcpp.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data_service_impl.h" | |||
| #include "minddata/dataset/engine/gnn/grpc_async_server.h" | |||
| #endif | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| class GraphDataImpl; | |||
| class GraphDataServer { | |||
| public: | |||
| enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; | |||
| GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port, | |||
| int32_t client_num, bool auto_shutdown); | |||
| ~GraphDataServer() = default; | |||
| Status Init(); | |||
| Status Stop(); | |||
| Status ClientRegister(int32_t pid); | |||
| Status ClientUnRegister(int32_t pid); | |||
| enum ServerState state() { return state_; } | |||
| bool IsStoped() { return state_ == kGdsStopped; } | |||
| private: | |||
| void set_state(enum ServerState state) { state_ = state; } | |||
| Status InitGraphDataImpl(); | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| Status StartAsyncRpcService(); | |||
| #endif | |||
| Status JudgeAutoShutdownServer(); | |||
| std::string dataset_file_; | |||
| int32_t num_workers_; // The number of worker threads | |||
| int32_t client_num_; | |||
| int32_t max_connected_client_num_; | |||
| bool auto_shutdown_; | |||
| enum ServerState state_; | |||
| std::unique_ptr<TaskGroup> tg_; // Class for worker management | |||
| std::unique_ptr<GraphDataImpl> graph_data_impl_; | |||
| std::unordered_set<int32_t> client_pid_; | |||
| std::mutex mutex_; | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| std::unique_ptr<GraphDataServiceImpl> service_impl_; | |||
| std::unique_ptr<GrpcAsyncServer> async_server_; | |||
| #endif | |||
| }; | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| class UntypedCall { | |||
| public: | |||
| virtual ~UntypedCall() {} | |||
| virtual Status operator()() = 0; | |||
| }; | |||
| template <class ServiceImpl, class AsyncService, class RequestMessage, class ResponseMessage> | |||
| class CallData : public UntypedCall { | |||
| public: | |||
| enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; | |||
| using EnqueueFunction = void (AsyncService::*)(grpc::ServerContext *, RequestMessage *, | |||
| grpc::ServerAsyncResponseWriter<ResponseMessage> *, | |||
| grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *); | |||
| using HandleRequestFunction = grpc::Status (ServiceImpl::*)(grpc::ServerContext *, const RequestMessage *, | |||
| ResponseMessage *); | |||
| CallData(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq, | |||
| EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) | |||
| : status_(STATE::CREATE), | |||
| service_impl_(service_impl), | |||
| async_service_(async_service), | |||
| cq_(cq), | |||
| enqueue_function_(enqueue_function), | |||
| handle_request_function_(handle_request_function), | |||
| responder_(&ctx_) {} | |||
| ~CallData() = default; | |||
| static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq, | |||
| EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) { | |||
| auto call = new CallData<ServiceImpl, AsyncService, RequestMessage, ResponseMessage>( | |||
| service_impl, async_service, cq, enqueue_function, handle_request_function); | |||
| RETURN_IF_NOT_OK((*call)()); | |||
| return Status::OK(); | |||
| } | |||
| Status operator()() { | |||
| if (status_ == STATE::CREATE) { | |||
| status_ = STATE::PROCESS; | |||
| (async_service_->*enqueue_function_)(&ctx_, &request_, &responder_, cq_, cq_, this); | |||
| } else if (status_ == STATE::PROCESS) { | |||
| EnqueueRequest(service_impl_, async_service_, cq_, enqueue_function_, handle_request_function_); | |||
| status_ = STATE::FINISH; | |||
| grpc::Status s = (service_impl_->*handle_request_function_)(&ctx_, &request_, &response_); | |||
| responder_.Finish(response_, s, this); | |||
| } else { | |||
| GPR_ASSERT(status_ == STATE::FINISH); | |||
| delete this; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| STATE status_; | |||
| ServiceImpl *service_impl_; | |||
| AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| EnqueueFunction enqueue_function_; | |||
| HandleRequestFunction handle_request_function_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<ResponseMessage> responder_; | |||
| RequestMessage request_; | |||
| ResponseMessage response_; | |||
| }; | |||
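
CallData is one leg of the standard gRPC async completion-queue pattern: each tag pulled off the queue is an UntypedCall whose operator() advances the CREATE -> PROCESS -> FINISH state machine. A plausible shape for the driver loop — the real one lives in GrpcAsyncServer::HandleRequest, outside this diff, and cq_ is assumed to be the server's completion queue:

void *tag = nullptr;
bool ok = false;
while (cq_->Next(&tag, &ok)) {  // blocks until some call changes state
  if (!ok) break;               // queue is shutting down
  RETURN_IF_NOT_OK(ProcessRequest(tag));  // casts tag to UntypedCall and runs (*call)()
}
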
| #define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg) \ | |||
| do { \ | |||
| Status s = \ | |||
| CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, request_msg, response_msg>::EnqueueRequest( \ | |||
| service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method, \ | |||
| &gnn::GraphDataServiceImpl::method); \ | |||
| RETURN_IF_NOT_OK(s); \ | |||
| } while (0) | |||
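
For reference, the token pasting turns the method name into the generated async stub, so a single use such as ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb) expands (modulo the do/while wrapper) to:

Status s = CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService,
                    GnnMetaInfoRequestPb, GnnMetaInfoResponsePb>::
    EnqueueRequest(service_impl_, &svc_, cq_.get(),
                   &GnnGraphData::AsyncService::RequestGetMetaInfo,
                   &gnn::GraphDataServiceImpl::GetMetaInfo);
RETURN_IF_NOT_OK(s);
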
| class GraphDataGrpcServer : public GrpcAsyncServer { | |||
| public: | |||
| GraphDataGrpcServer(const std::string &host, int32_t port, GraphDataServiceImpl *service_impl) | |||
| : GrpcAsyncServer(host, port), service_impl_(service_impl) {} | |||
| Status RegisterService(grpc::ServerBuilder *builder) { | |||
| builder->RegisterService(&svc_); | |||
| return Status::OK(); | |||
| } | |||
| Status EnqueueRequest() { | |||
| ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientRegister, GnnClientRegisterRequestPb, | |||
| GnnClientRegisterResponsePb); | |||
| ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientUnRegister, GnnClientUnRegisterRequestPb, | |||
| GnnClientUnRegisterResponsePb); | |||
| ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetGraphData, GnnGraphDataRequestPb, GnnGraphDataResponsePb); | |||
| ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb); | |||
| return Status::OK(); | |||
| } | |||
| Status ProcessRequest(void *tag) { | |||
| auto rq = static_cast<UntypedCall *>(tag); | |||
| RETURN_IF_NOT_OK((*rq)()); | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| GraphDataServiceImpl *service_impl_; | |||
| GnnGraphData::AsyncService svc_; | |||
| }; | |||
| #endif | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_ | |||
| @@ -0,0 +1,299 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/gnn/graph_data_service_impl.h" | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/gnn/tensor_proto.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data_server.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| using pFunction = Status (GraphDataServiceImpl::*)(const GnnGraphDataRequestPb *, GnnGraphDataResponsePb *); | |||
| static std::unordered_map<uint32_t, pFunction> g_get_graph_data_func_ = { | |||
| {GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes}, | |||
| {GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges}, | |||
| {GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges}, | |||
| {GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors}, | |||
| {GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors}, | |||
| {GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors}, | |||
| {RANDOM_WALK, &GraphDataServiceImpl::RandomWalk}, | |||
| {GET_NODE_FEATURE, &GraphDataServiceImpl::GetNodeFeature}, | |||
| {GET_EDGE_FEATURE, &GraphDataServiceImpl::GetEdgeFeature}}; | |||
| GraphDataServiceImpl::GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl) | |||
| : server_(server), graph_data_impl_(graph_data_impl) {} | |||
| Status GraphDataServiceImpl::FillDefaultFeature(GnnClientRegisterResponsePb *response) { | |||
| const auto default_node_features = graph_data_impl_->GetAllDefaultNodeFeatures(); | |||
| for (const auto &feature : *default_node_features) { | |||
| GnnFeatureInfoPb *feature_info = response->add_default_node_feature(); | |||
| feature_info->set_type(feature.first); | |||
| RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature())); | |||
| } | |||
| const auto default_edge_features = graph_data_impl_->GetAllDefaultEdgeFeatures(); | |||
| for (const auto &feature : *default_edge_features) { | |||
| GnnFeatureInfoPb *feature_info = response->add_default_edge_feature(); | |||
| feature_info->set_type(feature.first); | |||
| RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature())); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context, | |||
| const GnnClientRegisterRequestPb *request, | |||
| GnnClientRegisterResponsePb *response) { | |||
| Status s = server_->ClientRegister(request->pid()); | |||
| if (s.IsOk()) { | |||
| switch (server_->state()) { | |||
| case GraphDataServer::kGdsUninit: | |||
| case GraphDataServer::kGdsInitializing: | |||
| response->set_error_msg("Initializing"); | |||
| break; | |||
| case GraphDataServer::kGdsRunning: | |||
| response->set_error_msg("Success"); | |||
| response->set_data_schema(graph_data_impl_->GetDataSchema()); | |||
| response->set_shared_memory_key(graph_data_impl_->GetSharedMemoryKey()); | |||
| response->set_shared_memory_size(graph_data_impl_->GetSharedMemorySize()); | |||
| s = FillDefaultFeature(response); | |||
| if (!s.IsOk()) { | |||
| response->set_error_msg(s.ToString()); | |||
| } | |||
| break; | |||
| case GraphDataServer::kGdsStopped: | |||
| response->set_error_msg("Stoped"); | |||
| break; | |||
| } | |||
| } else { | |||
| response->set_error_msg(s.ToString()); | |||
| } | |||
| return ::grpc::Status::OK; | |||
| } | |||
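
On the client side, the state-dependent error_msg values above imply a retry loop around registration. A hedged sketch against the generated synchronous stub — the stub method follows the standard grpc codegen shape, and the retry interval is an arbitrary choice:

GnnClientRegisterRequestPb request;
request.set_pid(getpid());
GnnClientRegisterResponsePb response;
while (true) {
  grpc::ClientContext ctx;  // a ClientContext cannot be reused across calls
  grpc::Status rc = stub_->ClientRegister(&ctx, request, &response);
  if (!rc.ok()) RETURN_STATUS_UNEXPECTED(rc.error_message());
  if (response.error_msg() != "Initializing") break;  // server still loading the graph
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
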
| grpc::Status GraphDataServiceImpl::ClientUnRegister(grpc::ServerContext *context, | |||
| const GnnClientUnRegisterRequestPb *request, | |||
| GnnClientUnRegisterResponsePb *response) { | |||
| Status s = server_->ClientUnRegister(request->pid()); | |||
| if (s.IsOk()) { | |||
| response->set_error_msg("Success"); | |||
| } else { | |||
| response->set_error_msg(s.ToString()); | |||
| } | |||
| return ::grpc::Status::OK; | |||
| } | |||
| grpc::Status GraphDataServiceImpl::GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request, | |||
| GnnGraphDataResponsePb *response) { | |||
| // MS_LOG(INFO) << "#### receive GetGraphData:" << request->op_name(); | |||
| Status s; | |||
| auto iter = g_get_graph_data_func_.find(request->op_name()); | |||
| if (iter != g_get_graph_data_func_.end()) { | |||
| pFunction func = iter->second; | |||
| s = (this->*func)(request, response); | |||
| if (s.IsOk()) { | |||
| response->set_error_msg("Success"); | |||
| } else { | |||
| response->set_error_msg(s.ToString()); | |||
| } | |||
| } else { | |||
| response->set_error_msg("Invalid op name."); | |||
| } | |||
| // MS_LOG(INFO) << "#### end receive GetGraphData:" << request->op_name(); | |||
| return ::grpc::Status::OK; | |||
| } | |||
| grpc::Status GraphDataServiceImpl::GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request, | |||
| GnnMetaInfoResponsePb *response) { | |||
| MetaInfo meta_info; | |||
| Status s = graph_data_impl_->GetMetaInfo(&meta_info); | |||
| if (s.IsOk()) { | |||
| response->set_error_msg("Success"); | |||
| for (const auto &type : meta_info.node_type) { | |||
| auto node_info = response->add_node_info(); | |||
| node_info->set_type(static_cast<google::protobuf::int32>(type)); | |||
| auto itr = meta_info.node_num.find(type); | |||
| if (itr != meta_info.node_num.end()) { | |||
| node_info->set_num(static_cast<google::protobuf::int32>(itr->second)); | |||
| } else { | |||
| node_info->set_num(0); | |||
| } | |||
| } | |||
| for (const auto &type : meta_info.edge_type) { | |||
| auto edge_info = response->add_edge_info(); | |||
| edge_info->set_type(static_cast<google::protobuf::int32>(type)); | |||
| auto itr = meta_info.edge_num.find(type); | |||
| if (itr != meta_info.edge_num.end()) { | |||
| edge_info->set_num(static_cast<google::protobuf::int32>(itr->second)); | |||
| } else { | |||
| edge_info->set_num(0); | |||
| } | |||
| } | |||
| for (const auto &type : meta_info.node_feature_type) { | |||
| response->add_node_feature_type(static_cast<google::protobuf::int32>(type)); | |||
| } | |||
| for (const auto &type : meta_info.edge_feature_type) { | |||
| response->add_edge_feature_type(static_cast<google::protobuf::int32>(type)); | |||
| } | |||
| } else { | |||
| response->set_error_msg(s.ToString()); | |||
| } | |||
| return ::grpc::Status::OK; | |||
| } | |||
| Status GraphDataServiceImpl::GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of node types is not 1"); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetAllNodes(static_cast<NodeType>(request->type()[0]), &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetAllEdges(static_cast<EdgeType>(request->type()[0]), &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input edge id is empty"); | |||
| std::vector<EdgeIdType> edge_list; | |||
| edge_list.resize(request->id().size()); | |||
| std::transform(request->id().begin(), request->id().end(), edge_list.begin(), | |||
| [](const google::protobuf::int32 id) { return static_cast<EdgeIdType>(id); }); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetNodesFromEdges(edge_list, &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of neighbor types is not 1"); | |||
| std::vector<NodeIdType> node_list; | |||
| node_list.resize(request->id().size()); | |||
| std::transform(request->id().begin(), request->id().end(), node_list.begin(), | |||
| [](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); }); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetAllNeighbors(node_list, static_cast<NodeType>(request->type()[0]), &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *request, | |||
| GnnGraphDataResponsePb *response) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() > 0, "The input neighbor number is empty"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input neighbor type is empty"); | |||
| std::vector<NodeIdType> node_list; | |||
| node_list.resize(request->id().size()); | |||
| std::transform(request->id().begin(), request->id().end(), node_list.begin(), | |||
| [](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); }); | |||
| std::vector<NodeIdType> neighbor_nums; | |||
| neighbor_nums.resize(request->number().size()); | |||
| std::transform(request->number().begin(), request->number().end(), neighbor_nums.begin(), | |||
| [](const google::protobuf::int32 num) { return static_cast<NodeIdType>(num); }); | |||
| std::vector<NodeType> neighbor_types; | |||
| neighbor_types.resize(request->type().size()); | |||
| std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(), | |||
| [](const google::protobuf::int32 type) { return static_cast<NodeType>(type); }); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, | |||
| GnnGraphDataResponsePb *response) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() == 1, "The number of neighbor numbers is not 1"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of neighbor types is not 1"); | |||
| std::vector<NodeIdType> node_list; | |||
| node_list.resize(request->id().size()); | |||
| std::transform(request->id().begin(), request->id().end(), node_list.begin(), | |||
| [](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); }); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetNegSampledNeighbors(node_list, static_cast<NodeIdType>(request->number()[0]), | |||
| static_cast<NodeType>(request->type()[0]), &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input meta path is empty"); | |||
| std::vector<NodeIdType> node_list; | |||
| node_list.resize(request->id().size()); | |||
| std::transform(request->id().begin(), request->id().end(), node_list.begin(), | |||
| [](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); }); | |||
| std::vector<NodeType> meta_path; | |||
| meta_path.resize(request->type().size()); | |||
| std::transform(request->type().begin(), request->type().end(), meta_path.begin(), | |||
| [](const google::protobuf::int32 type) { return static_cast<NodeType>(type); }); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->RandomWalk(node_list, meta_path, request->random_walk().p(), | |||
| request->random_walk().q(), request->random_walk().default_id(), | |||
| &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||
| std::shared_ptr<Tensor> nodes; | |||
| RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &nodes)); | |||
| for (const auto &type : request->type()) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetNodeFeatureSharedMemory(nodes, type, &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphDataServiceImpl::GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { | |||
| std::shared_ptr<Tensor> edges; | |||
| RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &edges)); | |||
| for (const auto &type : request->type()) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(graph_data_impl_->GetEdgeFeatureSharedMemory(edges, type, &tensor)); | |||
| TensorPb *result = response->add_result_data(); | |||
| RETURN_IF_NOT_OK(TensorToPb(tensor, result)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||
| #include "proto/gnn_graph_data.grpc.pb.h" | |||
| #include "proto/gnn_graph_data.pb.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| class GraphDataServer; | |||
| // class GraphDataServiceImpl : public GnnGraphData::Service { | |||
| class GraphDataServiceImpl { | |||
| public: | |||
| GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl); | |||
| ~GraphDataServiceImpl() = default; | |||
| grpc::Status ClientRegister(grpc::ServerContext *context, const GnnClientRegisterRequestPb *request, | |||
| GnnClientRegisterResponsePb *response); | |||
| grpc::Status ClientUnRegister(grpc::ServerContext *context, const GnnClientUnRegisterRequestPb *request, | |||
| GnnClientUnRegisterResponsePb *response); | |||
| grpc::Status GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request, | |||
| GnnGraphDataResponsePb *response); | |||
| grpc::Status GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request, | |||
| GnnMetaInfoResponsePb *response); | |||
| Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| Status GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); | |||
| private: | |||
| Status FillDefaultFeature(GnnClientRegisterResponsePb *response); | |||
| GraphDataServer *server_; | |||
| GraphDataImpl *graph_data_impl_; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_ | |||
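For orientation, a minimal sketch of the dispatch that GetGraphData presumably performs over the per-op handlers declared above; the GnnOpName constants and the op_name request field are assumptions for illustration, not taken from this patch.

// Hypothetical dispatcher over the per-op handlers (names assumed).
Status DispatchGraphData(GraphDataServiceImpl *svc, const GnnGraphDataRequestPb *request,
                         GnnGraphDataResponsePb *response) {
  switch (request->op_name()) {  // assumed op-code field on the request proto
    case GET_ALL_NODES:
      return svc->GetAllNodes(request, response);
    case RANDOM_WALK:
      return svc->RandomWalk(request, response);
    case GET_NODE_FEATURE:
      return svc->GetNodeFeature(request, response);
    default:
      RETURN_STATUS_UNEXPECTED("Unsupported graph data op.");
  }
}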
| @@ -0,0 +1,106 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/gnn/graph_feature_parser.h" | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| using mindrecord::MSRStatus; | |||
| GraphFeatureParser::GraphFeatureParser(const ShardColumn &shard_column) { | |||
| shard_column_ = std::make_unique<ShardColumn>(shard_column); | |||
| } | |||
| Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob, | |||
| std::shared_ptr<Tensor> *tensor) { | |||
| const unsigned char *data = nullptr; | |||
| std::unique_ptr<unsigned char[]> data_ptr; | |||
| uint64_t n_bytes = 0, col_type_size = 1; | |||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||
| std::vector<int64_t> column_shape; | |||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||
| &col_type_size, &column_shape); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "Failed to load column: " + key); | |||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})), | |||
| std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), | |||
| data, tensor)); | |||
| return Status::OK(); | |||
| } | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob, | |||
| GraphSharedMemory *shared_memory, | |||
| std::shared_ptr<Tensor> *out_tensor) { | |||
| const unsigned char *data = nullptr; | |||
| std::unique_ptr<unsigned char[]> data_ptr; | |||
| uint64_t n_bytes = 0, col_type_size = 1; | |||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||
| std::vector<int64_t> column_shape; | |||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||
| &col_type_size, &column_shape); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "Failed to load column: " + key); | |||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor)); | |||
| auto fea_itr = tensor->begin<int64_t>(); | |||
| int64_t offset = 0; | |||
| RETURN_IF_NOT_OK(shared_memory->InsertData(data, n_bytes, &offset)); | |||
| *fea_itr = offset; | |||
| ++fea_itr; | |||
| *fea_itr = n_bytes; | |||
| *out_tensor = std::move(tensor); | |||
| return Status::OK(); | |||
| } | |||
| #endif | |||
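The 2-element int64 tensor built above carries {offset, length} into the shared-memory segment instead of the feature bytes themselves. A sketch of how a consumer that has attached the same segment could resolve such a locator; the helper name is illustrative:

// Resolve a {offset, n_bytes} locator tensor into the raw feature bytes.
Status ReadFeatureFromSharedMemory(GraphSharedMemory *shared_memory, const std::shared_ptr<Tensor> &locator,
                                   std::vector<uint8_t> *bytes) {
  auto itr = locator->begin<int64_t>();
  int64_t offset = *itr;
  ++itr;
  int64_t n_bytes = *itr;
  bytes->resize(n_bytes);
  return shared_memory->GetData(bytes->data(), static_cast<int64_t>(bytes->size()), offset, n_bytes);
}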
| Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob, | |||
| std::vector<int32_t> *indices) { | |||
| const unsigned char *data = nullptr; | |||
| std::unique_ptr<unsigned char[]> data_ptr; | |||
| uint64_t n_bytes = 0, col_type_size = 1; | |||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||
| std::vector<int64_t> column_shape; | |||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||
| &col_type_size, &column_shape); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "Failed to load column: " + key); | |||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||
| for (uint64_t i = 0; i < n_bytes; i += col_type_size) { | |||
| int32_t feature_ind = -1; | |||
| if (col_type == mindrecord::ColumnInt32) { | |||
| feature_ind = *(reinterpret_cast<const int32_t *>(data + i)); | |||
| } else if (col_type == mindrecord::ColumnInt64) { | |||
| feature_ind = static_cast<int32_t>(*(reinterpret_cast<const int64_t *>(data + i))); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); | |||
| } | |||
| if (feature_ind >= 0) indices->push_back(feature_ind); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_ | |||
| #include <memory> | |||
| #include <queue> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/data_type.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/gnn/feature.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/mindrecord/include/shard_column.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| using mindrecord::ShardColumn; | |||
| class GraphFeatureParser { | |||
| public: | |||
| explicit GraphFeatureParser(const ShardColumn &shard_column); | |||
| ~GraphFeatureParser() = default; | |||
| // @param std::string key - column name | |||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | |||
| // @param std::vector<int32_t> *ind - return value, list of feature index in int32_t | |||
| // @return Status - the status code | |||
| Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, std::vector<int32_t> *ind); | |||
| // @param std::string &key - column name | |||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | |||
| // @param std::shared_ptr<Tensor> *tensor - return value feature tensor | |||
| // @return Status - the status code | |||
| Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, std::shared_ptr<Tensor> *tensor); | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| Status LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob, | |||
| GraphSharedMemory *shared_memory, std::shared_ptr<Tensor> *out_tensor); | |||
| #endif | |||
| private: | |||
| std::unique_ptr<ShardColumn> shard_column_; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_ | |||
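A usage sketch mirroring how GraphLoader drives this parser: decode the feature-index column of one blob row, then fetch each feature tensor by its derived column name. The "node_feature_*" naming follows the loader code in this patch; the wrapper function is illustrative.

// Load all node features present in a single mindrecord blob row.
Status LoadNodeFeatures(GraphFeatureParser *parser, const std::vector<uint8_t> &blob,
                        std::vector<std::shared_ptr<Tensor>> *features) {
  std::vector<int32_t> indices;
  RETURN_IF_NOT_OK(parser->LoadFeatureIndex("node_feature_index", blob, &indices));
  for (int32_t ind : indices) {
    std::shared_ptr<Tensor> tensor;
    RETURN_IF_NOT_OK(parser->LoadFeatureTensor("node_feature_" + std::to_string(ind), blob, &tensor));
    features->push_back(std::move(tensor));
  }
  return Status::OK();
}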
| @@ -13,41 +13,42 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | |||
| #include <future> | |||
| #include <tuple> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | |||
| #include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||
| #include "minddata/dataset/engine/gnn/local_edge.h" | |||
| #include "minddata/dataset/engine/gnn/local_node.h" | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| #include "minddata/mindrecord/include/shard_error.h" | |||
| using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>; | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| using mindrecord::MSRStatus; | |||
| GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) | |||
| : mr_path_(mr_filepath), | |||
| GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers, bool server_mode) | |||
| : graph_impl_(graph_impl), | |||
| mr_path_(mr_filepath), | |||
| num_workers_(num_workers), | |||
| row_id_(0), | |||
| shard_reader_(nullptr), | |||
| graph_feature_parser_(nullptr), | |||
| keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} | |||
| Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, | |||
| EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, | |||
| EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map, | |||
| DefaultEdgeFeatureMap *default_edge_feature_map) { | |||
| Status GraphLoader::GetNodesAndEdges() { | |||
| NodeIdMap *n_id_map = &graph_impl_->node_id_map_; | |||
| EdgeIdMap *e_id_map = &graph_impl_->edge_id_map_; | |||
| for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) { | |||
| while (!dq.empty()) { | |||
| std::shared_ptr<Node> node_ptr = dq.front(); | |||
| n_id_map->insert({node_ptr->id(), node_ptr}); | |||
| (*n_type_map)[node_ptr->type()].push_back(node_ptr->id()); | |||
| graph_impl_->node_type_map_[node_ptr->type()].push_back(node_ptr->id()); | |||
| dq.pop_front(); | |||
| } | |||
| } | |||
| @@ -63,15 +64,15 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N | |||
| RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); | |||
| RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); | |||
| e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ | |||
| (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id()); | |||
| graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); | |||
| dq.pop_front(); | |||
| } | |||
| } | |||
| for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); | |||
| for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); | |||
| for (auto &itr : graph_impl_->node_type_map_) itr.second.shrink_to_fit(); | |||
| for (auto &itr : graph_impl_->edge_type_map_) itr.second.shrink_to_fit(); | |||
| MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map); | |||
| MergeFeatureMaps(); | |||
| return Status::OK(); | |||
| } | |||
| @@ -92,13 +93,26 @@ Status GraphLoader::InitAndLoad() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "Failed to launch mindrecord reader."); | |||
| mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"]; | |||
| graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); | |||
| mindrecord::json schema = graph_impl_->data_schema_["schema"]; | |||
| for (const std::string &key : keys_) { | |||
| if (schema.find(key) == schema.end()) { | |||
| RETURN_STATUS_UNEXPECTED(key + " doesn't exist in schema: " + schema.dump()); | |||
| } | |||
| } | |||
| if (graph_impl_->server_mode_) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| int64_t total_blob_size = 0; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS, | |||
| "failed to get total blob size"); | |||
| graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_); | |||
| RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory()); | |||
| #endif | |||
| } | |||
| graph_feature_parser_ = std::make_unique<GraphFeatureParser>(*shard_reader_->GetShardColumn()); | |||
| // launching worker threads | |||
| for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { | |||
| RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); | |||
| @@ -116,18 +130,39 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec | |||
| NodeType node_type = static_cast<NodeType>(col_jsn["type"]); | |||
| (*node) = std::make_shared<LocalNode>(node_id, node_type); | |||
| std::vector<int32_t> indices; | |||
| RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices)); | |||
| for (int32_t ind : indices) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); | |||
| RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); | |||
| (*feature_map)[node_type].insert(ind); | |||
| if ((*default_feature)[ind] == nullptr) { | |||
| std::shared_ptr<Tensor> zero_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices)); | |||
| if (graph_impl_->server_mode_) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| for (int32_t ind : indices) { | |||
| std::shared_ptr<Tensor> tensor_sm; | |||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory( | |||
| "node_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm)); | |||
| RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true))); | |||
| (*feature_map)[node_type].insert(ind); | |||
| if ((*default_feature)[ind] == nullptr) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK( | |||
| graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor)); | |||
| std::shared_ptr<Tensor> zero_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||
| } | |||
| } | |||
| #endif | |||
| } else { | |||
| for (int32_t ind : indices) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK( | |||
| graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor)); | |||
| RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); | |||
| (*feature_map)[node_type].insert(ind); | |||
| if ((*default_feature)[ind] == nullptr) { | |||
| std::shared_ptr<Tensor> zero_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||
| } | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| @@ -143,63 +178,42 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec | |||
| std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1); | |||
| (*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst); | |||
| std::vector<int32_t> indices; | |||
| RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices)); | |||
| for (int32_t ind : indices) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); | |||
| RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); | |||
| (*feature_map)[edge_type].insert(ind); | |||
| if ((*default_feature)[ind] == nullptr) { | |||
| std::shared_ptr<Tensor> zero_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices)); | |||
| if (graph_impl_->server_mode_) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| for (int32_t ind : indices) { | |||
| std::shared_ptr<Tensor> tensor_sm; | |||
| RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory( | |||
| "edge_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm)); | |||
| RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true))); | |||
| (*feature_map)[edge_type].insert(ind); | |||
| if ((*default_feature)[ind] == nullptr) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK( | |||
| graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor)); | |||
| std::shared_ptr<Tensor> zero_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||
| } | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob, | |||
| const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) { | |||
| const unsigned char *data = nullptr; | |||
| std::unique_ptr<unsigned char[]> data_ptr; | |||
| uint64_t n_bytes = 0, col_type_size = 1; | |||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||
| std::vector<int64_t> column_shape; | |||
| MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( | |||
| key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); | |||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})), | |||
| std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), | |||
| data, tensor)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob, | |||
| const mindrecord::json &col_jsn, std::vector<int32_t> *indices) { | |||
| const unsigned char *data = nullptr; | |||
| std::unique_ptr<unsigned char[]> data_ptr; | |||
| uint64_t n_bytes = 0, col_type_size = 1; | |||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||
| std::vector<int64_t> column_shape; | |||
| MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( | |||
| key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); | |||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||
| for (int i = 0; i < n_bytes; i += col_type_size) { | |||
| int32_t feature_ind = -1; | |||
| if (col_type == mindrecord::ColumnInt32) { | |||
| feature_ind = *(reinterpret_cast<const int32_t *>(data + i)); | |||
| } else if (col_type == mindrecord::ColumnInt64) { | |||
| feature_ind = *(reinterpret_cast<const int64_t *>(data + i)); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); | |||
| #endif | |||
| } else { | |||
| for (int32_t ind : indices) { | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK( | |||
| graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor)); | |||
| RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); | |||
| (*feature_map)[edge_type].insert(ind); | |||
| if ((*default_feature)[ind] == nullptr) { | |||
| std::shared_ptr<Tensor> zero_tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); | |||
| RETURN_IF_NOT_OK(zero_tensor->Zero()); | |||
| (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); | |||
| } | |||
| } | |||
| if (feature_ind >= 0) indices->push_back(feature_ind); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -234,21 +248,19 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) { | |||
| return Status::OK(); | |||
| } | |||
| void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, | |||
| DefaultNodeFeatureMap *default_node_feature_map, | |||
| DefaultEdgeFeatureMap *default_edge_feature_map) { | |||
| void GraphLoader::MergeFeatureMaps() { | |||
| for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { | |||
| for (auto &m : n_feature_maps_[wkr_id]) { | |||
| for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); | |||
| for (auto &n : m.second) graph_impl_->node_feature_map_[m.first].insert(n); | |||
| } | |||
| for (auto &m : e_feature_maps_[wkr_id]) { | |||
| for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); | |||
| for (auto &n : m.second) graph_impl_->edge_feature_map_[m.first].insert(n); | |||
| } | |||
| for (auto &m : default_node_feature_maps_[wkr_id]) { | |||
| (*default_node_feature_map)[m.first] = m.second; | |||
| graph_impl_->default_node_feature_map_[m.first] = m.second; | |||
| } | |||
| for (auto &m : default_edge_feature_maps_[wkr_id]) { | |||
| (*default_edge_feature_map)[m.first] = m.second; | |||
| graph_impl_->default_edge_feature_map_[m.first] = m.second; | |||
| } | |||
| } | |||
| n_feature_maps_.clear(); | |||
| @@ -26,10 +26,13 @@ | |||
| #include "minddata/dataset/core/data_type.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/engine/gnn/edge.h" | |||
| #include "minddata/dataset/engine/gnn/feature.h" | |||
| #include "minddata/dataset/engine/gnn/graph.h" | |||
| #include "minddata/dataset/engine/gnn/graph_feature_parser.h" | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||
| #endif | |||
| #include "minddata/dataset/engine/gnn/node.h" | |||
| #include "minddata/dataset/engine/gnn/edge.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/mindrecord/include/shard_reader.h" | |||
| namespace mindspore { | |||
| @@ -46,13 +49,15 @@ using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureTy | |||
| using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; | |||
| using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; | |||
| class GraphDataImpl; | |||
| // this class interfaces with the underlying storage format (mindrecord) | |||
| // it returns raw nodes and edges via GetNodesAndEdges | |||
| // it is then the responsibility of graph to construct itself based on the nodes and edges | |||
| // if needed, this class could become a base where each derived class handles a specific storage format | |||
| class GraphLoader { | |||
| public: | |||
| explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4); | |||
| GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false); | |||
| ~GraphLoader() = default; | |||
| // Init mindrecord and load everything into memory multi-threaded | |||
| @@ -63,8 +68,7 @@ class GraphLoader { | |||
| // nodes and edges are added to map without any connection. That's because these nodes and edges are read in | |||
| // random order. src_node and dst_node in Edge are node_id only with -1 as type. | |||
| // features attached to each node and edge are expected to be filled correctly | |||
| Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, | |||
| DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); | |||
| Status GetNodesAndEdges(); | |||
| private: | |||
| // | |||
| @@ -92,29 +96,15 @@ class GraphLoader { | |||
| Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge, | |||
| EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); | |||
| // @param std::string key - column name | |||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | |||
| // @param mindrecord::json &jsn - contains raw data | |||
| // @param std::vector<int32_t> *ind - return value, list of feature index in int32_t | |||
| // @return Status - the status code | |||
| Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn, | |||
| std::vector<int32_t> *ind); | |||
| // @param std::string &key - column name | |||
| // @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord | |||
| // @param mindrecord::json &jsn - contains raw data | |||
| // @param std::shared_ptr<Tensor> *tensor - return value feature tensor | |||
| // @return Status - the status code | |||
| Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn, | |||
| std::shared_ptr<Tensor> *tensor); | |||
| // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 | |||
| void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); | |||
| void MergeFeatureMaps(); | |||
| GraphDataImpl *graph_impl_; | |||
| std::string mr_path_; | |||
| const int32_t num_workers_; | |||
| std::atomic_int row_id_; | |||
| std::string mr_path_; | |||
| std::unique_ptr<ShardReader> shard_reader_; | |||
| std::unique_ptr<GraphFeatureParser> graph_feature_parser_; | |||
| std::vector<std::deque<std::shared_ptr<Node>>> n_deques_; | |||
| std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_; | |||
| std::vector<NodeFeatureMap> n_feature_maps_; | |||
| @@ -0,0 +1,134 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/gnn/graph_shared_memory.h" | |||
| #include <string> | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| GraphSharedMemory::GraphSharedMemory(int64_t memory_size, key_t memory_key) | |||
| : memory_size_(memory_size), | |||
| memory_key_(memory_key), | |||
| memory_ptr_(nullptr), | |||
| memory_offset_(0), | |||
| is_new_create_(false) { | |||
| std::stringstream stream; | |||
| stream << std::hex << memory_key_; | |||
| memory_key_str_ = stream.str(); | |||
| } | |||
| GraphSharedMemory::GraphSharedMemory(int64_t memory_size, const std::string &mr_file) | |||
| : mr_file_(mr_file), | |||
| memory_size_(memory_size), | |||
| memory_key_(-1), | |||
| memory_ptr_(nullptr), | |||
| memory_offset_(0), | |||
| is_new_create_(false) {} | |||
| GraphSharedMemory::~GraphSharedMemory() { | |||
| if (is_new_create_) { | |||
| (void)DeleteSharedMemory(); | |||
| } | |||
| } | |||
| Status GraphSharedMemory::CreateSharedMemory() { | |||
| if (memory_key_ == -1) { | |||
| // ftok to generate unique key | |||
| memory_key_ = ftok(mr_file_.data(), kGnnSharedMemoryId); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(memory_key_ != -1, "Failed to get key of shared memory. file_name:" + mr_file_); | |||
| std::stringstream stream; | |||
| stream << std::hex << memory_key_; | |||
| memory_key_str_ = stream.str(); | |||
| } | |||
| int shmflg = (0666 | IPC_CREAT | IPC_EXCL); | |||
| Status s = SharedMemoryImpl(shmflg); | |||
| if (s.IsOk()) { | |||
| is_new_create_ = true; | |||
| MS_LOG(INFO) << "Create shared memory success, key=0x" << memory_key_str_; | |||
| } else { | |||
| MS_LOG(WARNING) << "Shared memory with the same key may already exist, key=0x" << memory_key_str_; | |||
| shmflg = (0666 | IPC_CREAT); | |||
| s = SharedMemoryImpl(shmflg); | |||
| if (!s.IsOk()) { | |||
| RETURN_STATUS_UNEXPECTED("Create shared memory fao;ed, key=0x" + memory_key_str_); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GraphSharedMemory::GetSharedMemory() { | |||
| int shmflg = 0; | |||
| RETURN_IF_NOT_OK(SharedMemoryImpl(shmflg)); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphSharedMemory::DeleteSharedMemory() { | |||
| int shmid = shmget(memory_key_, 0, 0); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_); | |||
| int result = shmctl(shmid, IPC_RMID, 0); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(result != -1, "Failed to delete shared memory. key=0x" + memory_key_str_); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphSharedMemory::SharedMemoryImpl(const int &shmflg) { | |||
| // shmget returns an identifier in shmid | |||
| int shmid = shmget(memory_key_, memory_size_, shmflg); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_); | |||
| // shmat to attach to shared memory | |||
| auto data = shmat(shmid, reinterpret_cast<void *>(0), 0); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(data != reinterpret_cast<void *>(-1), "Failed to attach shared memory. key=0x" + memory_key_str_); | |||
| memory_ptr_ = reinterpret_cast<uint8_t *>(data); | |||
| return Status::OK(); | |||
| } | |||
| Status GraphSharedMemory::InsertData(const uint8_t *data, int64_t len, int64_t *offset) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(len > 0, "Input len is invalid."); | |||
| std::lock_guard<std::mutex> lck(mutex_); | |||
| CHECK_FAIL_RETURN_UNEXPECTED((memory_size_ - memory_offset_ >= len), | |||
| "Insufficient shared memory space to insert data."); | |||
| if (EOK != memcpy_s(memory_ptr_ + memory_offset_, memory_size_ - memory_offset_, data, len)) { | |||
| RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory."); | |||
| } | |||
| *offset = memory_offset_; | |||
| memory_offset_ += len; | |||
| return Status::OK(); | |||
| } | |||
| Status GraphSharedMemory::GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(get_data_len > 0, "Input get_data_len is invalid."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(data_len >= get_data_len, "Insufficient target address space."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(memory_size_ >= get_data_len + offset, | |||
| "get_data_len is too large, beyond the space of shared memory."); | |||
| if (EOK != memcpy_s(data, data_len, memory_ptr_ + offset, get_data_len)) { | |||
| RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
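A minimal end-to-end sketch of the intended flow, assuming the server hands the key to clients (e.g. during client registration): the server creates the segment keyed off the mindrecord path, and a client attaches by key. Both wrappers are illustrative.

// Server side: create the segment; the owner must outlive all clients,
// because the destructor removes a segment it created.
Status ServerSetup(int64_t total_blob_size, const std::string &mr_file,
                   std::unique_ptr<GraphSharedMemory> *out) {
  auto mem = std::make_unique<GraphSharedMemory>(total_blob_size, mr_file);
  RETURN_IF_NOT_OK(mem->CreateSharedMemory());
  *out = std::move(mem);
  return Status::OK();
}

// Client side: attach to the existing segment by the published key.
Status ClientAttach(int64_t total_blob_size, key_t key, std::unique_ptr<GraphSharedMemory> *out) {
  auto mem = std::make_unique<GraphSharedMemory>(total_blob_size, key);
  RETURN_IF_NOT_OK(mem->GetSharedMemory());
  *out = std::move(mem);
  return Status::OK();
}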
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_ | |||
| #include <sys/ipc.h> | |||
| #include <sys/shm.h> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace gnn { | |||
| const int kGnnSharedMemoryId = 65; | |||
| class GraphSharedMemory { | |||
| public: | |||
| explicit GraphSharedMemory(int64_t memory_size, key_t memory_key); | |||
| explicit GraphSharedMemory(int64_t memory_size, const std::string &mr_file); | |||
| ~GraphSharedMemory(); | |||
| // @param uint8_t** shared_memory - shared memory address | |||
| // @return Status - the status code | |||
| Status CreateSharedMemory(); | |||
| // @param uint8_t** shared_memory - shared memory address | |||
| // @return Status - the status code | |||
| Status GetSharedMemory(); | |||
| Status DeleteSharedMemory(); | |||
| Status InsertData(const uint8_t *data, int64_t len, int64_t *offset); | |||
| Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len); | |||
| key_t memory_key() { return memory_key_; } | |||
| int64_t memory_size() { return memory_size_; } | |||
| private: | |||
| Status SharedMemoryImpl(const int &shmflg); | |||
| std::string mr_file_; | |||
| int64_t memory_size_; | |||
| key_t memory_key_; | |||
| std::string memory_key_str_; | |||
| uint8_t *memory_ptr_; | |||
| int64_t memory_offset_; | |||
| std::mutex mutex_; | |||
| bool is_new_create_; | |||
| }; | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_ | |||
| @@ -0,0 +1,82 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/gnn/grpc_async_server.h" | |||
| #include <limits> | |||
| #include "minddata/dataset/util/task_manager.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| GrpcAsyncServer::GrpcAsyncServer(const std::string &host, int32_t port) : host_(host), port_(port) {} | |||
| GrpcAsyncServer::~GrpcAsyncServer() { Stop(); } | |||
| Status GrpcAsyncServer::Run() { | |||
| std::string server_address = host_ + ":" + std::to_string(port_); | |||
| grpc::ServerBuilder builder; | |||
| // The default gRPC message size limit is 4 MB; raise it to 2 GB - 1 (INT32_MAX). | |||
| builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max()); | |||
| builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); | |||
| int port_tcpip = 0; | |||
| builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip); | |||
| RETURN_IF_NOT_OK(RegisterService(&builder)); | |||
| cq_ = builder.AddCompletionQueue(); | |||
| server_ = builder.BuildAndStart(); | |||
| if (server_) { | |||
| MS_LOG(INFO) << "Server listening on " << server_address; | |||
| } else { | |||
| std::string errMsg = "Fail to start server. "; | |||
| if (port_tcpip != port_) { | |||
| errMsg += "Unable to bind to address " + server_address + "."; | |||
| } | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GrpcAsyncServer::HandleRequest() { | |||
| bool success = false; | |||
| void *tag = nullptr; | |||
| // Loop over the completion queue. Each event that completes successfully | |||
| // returns the tag we supplied when enqueueing the request (an instance of | |||
| // CallData), and we simply invoke its functor. New instances must be | |||
| // created and injected into the queue first so there is work to complete. | |||
| RETURN_IF_NOT_OK(EnqueueRequest()); | |||
| while (cq_->Next(&tag, &success)) { | |||
| RETURN_IF_INTERRUPTED(); | |||
| if (success) { | |||
| RETURN_IF_NOT_OK(ProcessRequest(tag)); | |||
| } else { | |||
| MS_LOG(DEBUG) << "cq_->Next failed."; | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| void GrpcAsyncServer::Stop() { | |||
| if (server_) { | |||
| server_->Shutdown(); | |||
| } | |||
| // Always shutdown the completion queue after the server. | |||
| if (cq_) { | |||
| cq_->Shutdown(); | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "grpcpp/grpcpp.h" | |||
| #include "grpcpp/impl/codegen/async_unary_call.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \brief Async server base class | |||
| class GrpcAsyncServer { | |||
| public: | |||
| explicit GrpcAsyncServer(const std::string &host, int32_t port); | |||
| virtual ~GrpcAsyncServer(); | |||
| /// \brief Brings up gRPC server | |||
| /// \return Status - the status code | |||
| Status Run(); | |||
| /// \brief Entry function to handle async server request | |||
| Status HandleRequest(); | |||
| void Stop(); | |||
| virtual Status RegisterService(grpc::ServerBuilder *builder) = 0; | |||
| virtual Status EnqueueRequest() = 0; | |||
| virtual Status ProcessRequest(void *tag) = 0; | |||
| protected: | |||
| std::string host_; | |||
| int32_t port_; | |||
| std::unique_ptr<grpc::ServerCompletionQueue> cq_; | |||
| std::unique_ptr<grpc::Server> server_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_ | |||
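A sketch of a concrete subclass wired for one unary RPC, to show how the three pure virtuals cooperate with HandleRequest(). The Echo* message and service types and the CallData state machine are illustrative assumptions; the real GNN subclass is defined elsewhere in this patch.

class EchoAsyncServer : public GrpcAsyncServer {
 public:
  EchoAsyncServer(const std::string &host, int32_t port) : GrpcAsyncServer(host, port) {}

  // Per-call state; the pointer doubles as the completion-queue tag.
  struct CallData {
    explicit CallData(EchoAsyncServer *s) : server(s), responder(&ctx) {}
    EchoAsyncServer *server;
    grpc::ServerContext ctx;
    EchoRequestPb request;  // hypothetical generated message
    grpc::ServerAsyncResponseWriter<EchoReplyPb> responder;
    bool replied = false;
  };

  Status RegisterService(grpc::ServerBuilder *builder) override {
    builder->RegisterService(&svc_);
    return Status::OK();
  }

  Status EnqueueRequest() override {
    auto *call = new CallData(this);
    // Ask gRPC to fill `call` when the next Echo RPC arrives.
    svc_.RequestEcho(&call->ctx, &call->request, &call->responder, cq_.get(), cq_.get(), call);
    return Status::OK();
  }

  Status ProcessRequest(void *tag) override {
    auto *call = static_cast<CallData *>(tag);
    if (!call->replied) {
      RETURN_IF_NOT_OK(EnqueueRequest());  // keep one request slot outstanding
      EchoReplyPb reply;
      reply.set_message(call->request.message());  // hypothetical fields
      call->replied = true;
      call->responder.Finish(reply, grpc::Status::OK, call);
    } else {
      delete call;  // second event: Finish completed, reclaim the call
    }
    return Status::OK();
  }

 private:
  EchoService::AsyncService svc_;  // hypothetical generated async service
};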
| @@ -44,6 +44,7 @@ Status LocalEdge::UpdateFeature(const std::shared_ptr<Feature> &feature) { | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| } // namespace gnn | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -20,10 +20,10 @@ | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/engine/gnn/edge.h" | |||
| #include "minddata/dataset/engine/gnn/feature.h" | |||
| #include "minddata/dataset/engine/gnn/node.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -20,9 +20,9 @@ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/engine/gnn/node.h" | |||
| #include "minddata/dataset/engine/gnn/feature.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -20,8 +20,8 @@ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/engine/gnn/feature.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -0,0 +1,84 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/gnn/tensor_proto.h" | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include <unordered_map> | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const std::unordered_map<DataTypePb, DataType::Type> g_pb2datatype_map{ | |||
| {DataTypePb::DE_PB_UNKNOWN, DataType::DE_UNKNOWN}, {DataTypePb::DE_PB_BOOL, DataType::DE_BOOL}, | |||
| {DataTypePb::DE_PB_INT8, DataType::DE_INT8}, {DataTypePb::DE_PB_UINT8, DataType::DE_UINT8}, | |||
| {DataTypePb::DE_PB_INT16, DataType::DE_INT16}, {DataTypePb::DE_PB_UINT16, DataType::DE_UINT16}, | |||
| {DataTypePb::DE_PB_INT32, DataType::DE_INT32}, {DataTypePb::DE_PB_UINT32, DataType::DE_UINT32}, | |||
| {DataTypePb::DE_PB_INT64, DataType::DE_INT64}, {DataTypePb::DE_PB_UINT64, DataType::DE_UINT64}, | |||
| {DataTypePb::DE_PB_FLOAT16, DataType::DE_FLOAT16}, {DataTypePb::DE_PB_FLOAT32, DataType::DE_FLOAT32}, | |||
| {DataTypePb::DE_PB_FLOAT64, DataType::DE_FLOAT64}, {DataTypePb::DE_PB_STRING, DataType::DE_STRING}, | |||
| }; | |||
| const std::unordered_map<DataType::Type, DataTypePb> g_datatype2pb_map{ | |||
| {DataType::DE_UNKNOWN, DataTypePb::DE_PB_UNKNOWN}, {DataType::DE_BOOL, DataTypePb::DE_PB_BOOL}, | |||
| {DataType::DE_INT8, DataTypePb::DE_PB_INT8}, {DataType::DE_UINT8, DataTypePb::DE_PB_UINT8}, | |||
| {DataType::DE_INT16, DataTypePb::DE_PB_INT16}, {DataType::DE_UINT16, DataTypePb::DE_PB_UINT16}, | |||
| {DataType::DE_INT32, DataTypePb::DE_PB_INT32}, {DataType::DE_UINT32, DataTypePb::DE_PB_UINT32}, | |||
| {DataType::DE_INT64, DataTypePb::DE_PB_INT64}, {DataType::DE_UINT64, DataTypePb::DE_PB_UINT64}, | |||
| {DataType::DE_FLOAT16, DataTypePb::DE_PB_FLOAT16}, {DataType::DE_FLOAT32, DataTypePb::DE_PB_FLOAT32}, | |||
| {DataType::DE_FLOAT64, DataTypePb::DE_PB_FLOAT64}, {DataType::DE_STRING, DataTypePb::DE_PB_STRING}, | |||
| }; | |||
| Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer"); | |||
| std::vector<dsize_t> shape = tensor->shape().AsVector(); | |||
| for (auto dim : shape) { | |||
| tensor_pb->add_dims(static_cast<google::protobuf::int64>(dim)); | |||
| } | |||
| auto iter = g_datatype2pb_map.find(tensor->type().value()); | |||
| if (iter == g_datatype2pb_map.end()) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid tensor type: " + tensor->type().ToString()); | |||
| } | |||
| tensor_pb->set_tensor_type(iter->second); | |||
| tensor_pb->set_data(tensor->GetBuffer(), tensor->SizeInBytes()); | |||
| return Status::OK(); | |||
| } | |||
| Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer"); | |||
| std::vector<dsize_t> shape; | |||
| shape.resize(tensor_pb->dims().size()); | |||
| std::transform(tensor_pb->dims().begin(), tensor_pb->dims().end(), shape.begin(), | |||
| [](const google::protobuf::int64 dim) { return static_cast<dsize_t>(dim); }); | |||
| auto iter = g_pb2datatype_map.find(tensor_pb->tensor_type()); | |||
| if (iter == g_pb2datatype_map.end()) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid Tensor_pb type: " + std::to_string(tensor_pb->tensor_type())); | |||
| } | |||
| DataType::Type type = iter->second; | |||
| std::shared_ptr<Tensor> tensor_out; | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape(shape), DataType(type), | |||
| reinterpret_cast<const unsigned char *>(tensor_pb->data().data()), | |||
| tensor_pb->data().size(), &tensor_out)); | |||
| *tensor = std::move(tensor_out); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
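A round-trip sketch using the two converters above, e.g. for a unit test; the wrapper is illustrative.

// Encode a Tensor into its protobuf form and decode it back.
Status RoundTrip(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) {
  TensorPb pb;
  RETURN_IF_NOT_OK(TensorToPb(in, &pb));
  return PbToTensor(&pb, out);
}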
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_ | |||
| #include <deque> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "proto/gnn_tensor.pb.h" | |||
| #include "minddata/dataset/core/tensor.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb); | |||
| Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_ | |||
| @@ -61,6 +61,7 @@ const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = { | |||
| class ShardColumn { | |||
| public: | |||
| explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true); | |||
| explicit ShardColumn(const json &schema_json, bool compress_integer = true); | |||
| ~ShardColumn() = default; | |||
| @@ -72,23 +73,29 @@ class ShardColumn { | |||
| std::vector<int64_t> *column_shape); | |||
| /// \brief compress blob | |||
| std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob); | |||
| std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size); | |||
| /// \brief check if blob compressed | |||
| bool CheckCompressBlob() const { return has_compress_blob_; } | |||
| /// \brief getter | |||
| uint64_t GetNumBlobColumn() const { return num_blob_column_; } | |||
| /// \brief getter | |||
| std::vector<std::string> GetColumnName() { return column_name_; } | |||
| /// \brief getter | |||
| std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; } | |||
| /// \brief getter | |||
| std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; } | |||
| /// \brief get column value from blob | |||
| MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||
| uint64_t *const n_bytes); | |||
| /// \brief get column type | |||
| std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name, | |||
| ColumnDataType *column_data_type, | |||
| uint64_t *column_data_type_size, | |||
| @@ -99,6 +106,9 @@ class ShardColumn { | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes); | |||
| private: | |||
| /// \brief intialization | |||
| void Init(const json &schema_json, bool compress_integer = true); | |||
| /// \brief get float value from json | |||
| template <typename T> | |||
| MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double); | |||
| @@ -65,6 +65,11 @@ class ShardHeader { | |||
| /// \return the Statistic | |||
| std::vector<std::shared_ptr<Statistics>> GetStatistics(); | |||
| /// \brief get the statistic of slim size | |||
| /// \param[in] slim_size_json statistic info of slim size | |||
| /// \return the slim size | |||
| int64_t GetSlimSizeStatistic(const json &slim_size_json); | |||
| /// \brief get the fields of the index | |||
| /// \return the fields of the index | |||
| std::vector<std::pair<uint64_t, std::string>> GetFields(); | |||
| @@ -114,10 +119,14 @@ class ShardHeader { | |||
| uint64_t GetPageSize() const { return page_size_; } | |||
| uint64_t GetCompressionSize() const { return compression_size_; } | |||
| void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } | |||
| void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } | |||
| void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; } | |||
| std::vector<std::string> SerializeHeader(); | |||
| MSRStatus PagesToFile(const std::string dump_file_name); | |||
| @@ -177,6 +186,7 @@ class ShardHeader { | |||
| uint32_t shard_count_; | |||
| uint64_t header_size_; | |||
| uint64_t page_size_; | |||
| uint64_t compression_size_; | |||
| std::shared_ptr<Index> index_; | |||
| std::vector<std::string> shard_addresses_; | |||
| @@ -209,6 +209,9 @@ class ShardReader { | |||
| /// \brief get all classes | |||
| MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories); | |||
| /// \brief get the size of blob data | |||
| MSRStatus GetTotalBlobSize(int64_t *total_blob_size); | |||
| protected: | |||
| /// \brief sqlite call back function | |||
| static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); | |||
| @@ -323,6 +326,7 @@ class ShardReader { | |||
| const std::string kThreadName = "THRD_ITER_"; // prefix of thread name | |||
| std::vector<std::thread> thread_set_; // thread list | |||
| int num_rows_; // number of rows | |||
| int64_t total_blob_size_; // total size of blob data | |||
| std::mutex mtx_delivery_; // locker for delivery | |||
| std::condition_variable cv_delivery_; // conditional variable for delivery | |||
| std::condition_variable cv_iterator_; // conditional variable for iterator | |||
| @@ -257,6 +257,7 @@ class ShardWriter { | |||
| std::mutex check_mutex_; // mutex for data check | |||
| std::atomic<bool> flag_{false}; | |||
| std::atomic<int64_t> compression_size_; | |||
| }; | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -43,6 +43,7 @@ ShardReader::ShardReader() { | |||
| page_size_ = 0; | |||
| header_size_ = 0; | |||
| num_rows_ = 0; | |||
| total_blob_size_ = 0; | |||
| num_padded_ = 0; | |||
| } | |||
| @@ -55,9 +56,11 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s | |||
| return {FAILED, {}}; | |||
| } | |||
| auto header = ret.second; | |||
| meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, | |||
| {"version", header["version"]}, {"index_fields", header["index_fields"]}, | |||
| {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; | |||
| uint64_t compression_size = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0; | |||
| meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, | |||
| {"compression_size", compression_size}, {"version", header["version"]}, | |||
| {"index_fields", header["index_fields"]}, {"schema", header["schema"]}, | |||
| {"blob_fields", header["blob_fields"]}}; | |||
| return {SUCCESS, header["shard_addresses"]}; | |||
| } | |||
| @@ -145,6 +148,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa | |||
| for (const auto &rg : row_group_summary) { | |||
| num_rows_ += std::get<3>(rg); | |||
| } | |||
| auto disk_size = page_size_ * row_group_summary.size(); | |||
| auto compression_size = shard_header_->GetCompressionSize(); | |||
| total_blob_size_ = disk_size + compression_size; | |||
| MS_LOG(INFO) << "Blob data size, on disk: " << disk_size << " , addtional uncompression: " << compression_size | |||
| << " , Total: " << total_blob_size_; | |||
| MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; | |||
| @@ -272,6 +280,11 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar | |||
| return row_group_summary; | |||
| } | |||
| MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) { | |||
| *total_blob_size = total_blob_size_; | |||
| return SUCCESS; | |||
| } | |||
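To make the accounting above concrete, here is a hedged Python sketch of the same estimate: blob bytes occupied by pages on disk, plus the bytes that integer compression removed, which is roughly what the data expands back to when decompressed.

def estimate_total_blob_size(page_size, num_row_groups, compression_size):
    # Disk footprint of the blob pages plus the bytes saved by
    # compression, i.e. the expected in-memory (decompressed) size.
    disk_size = page_size * num_row_groups
    return disk_size + compression_size

# e.g. 32 MiB pages, 4 row groups, 5 MiB saved by integer compression
print(estimate_total_blob_size(32 << 20, 4, 5 << 20))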
| MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, | |||
| std::shared_ptr<std::fstream> fs, | |||
| std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id, | |||
| @@ -28,11 +28,9 @@ using mindspore::MsLogLevel::INFO; | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| ShardWriter::ShardWriter() | |||
| : shard_count_(1), | |||
| header_size_(kDefaultHeaderSize), | |||
| page_size_(kDefaultPageSize), | |||
| row_count_(0), | |||
| schema_count_(1) {} | |||
| : shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) { | |||
| compression_size_ = 0; | |||
| } | |||
| ShardWriter::~ShardWriter() { | |||
| for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) { | |||
| @@ -201,6 +199,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { | |||
| if (ret == FAILED) { | |||
| return FAILED; | |||
| } | |||
| compression_size_ = shard_header_->GetCompressionSize(); | |||
| ret = Open(real_addresses, true); | |||
| if (ret == FAILED) { | |||
| MS_LOG(ERROR) << "Open file failed"; | |||
| @@ -614,7 +613,9 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json> | |||
| // compress blob | |||
| if (shard_column_->CheckCompressBlob()) { | |||
| for (auto &blob : blob_data) { | |||
| blob = shard_column_->CompressBlob(blob); | |||
| int64_t compression_bytes = 0; | |||
| blob = shard_column_->CompressBlob(blob, &compression_bytes); | |||
| compression_size_ += compression_bytes; | |||
| } | |||
| } | |||
| @@ -1177,6 +1178,11 @@ MSRStatus ShardWriter::WriteShardHeader() { | |||
| MS_LOG(ERROR) << "Shard header is null"; | |||
| return FAILED; | |||
| } | |||
| int64_t compression_temp = compression_size_; | |||
| uint64_t compression_size = compression_temp > 0 ? compression_temp : 0; | |||
| shard_header_->SetCompressionSize(compression_size); | |||
| auto shard_header = shard_header_->SerializeHeader(); | |||
| // Write header data to multi files | |||
| if (shard_count_ > static_cast<int>(file_streams_.size()) || shard_count_ > static_cast<int>(shard_header.size())) { | |||
| @@ -24,7 +24,15 @@ namespace mindspore { | |||
| namespace mindrecord { | |||
| ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) { | |||
| auto first_schema = shard_header->GetSchemas()[0]; | |||
| auto schema = first_schema->GetSchema()["schema"]; | |||
| json schema_json = first_schema->GetSchema(); | |||
| Init(schema_json, compress_integer); | |||
| } | |||
| ShardColumn::ShardColumn(const json &schema_json, bool compress_integer) { Init(schema_json, compress_integer); } | |||
| void ShardColumn::Init(const json &schema_json, bool compress_integer) { | |||
| auto schema = schema_json["schema"]; | |||
| auto blob_fields = schema_json["blob_fields"]; | |||
| bool has_integer_array = false; | |||
| for (json::iterator it = schema.begin(); it != schema.end(); ++it) { | |||
| @@ -52,8 +60,6 @@ ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool | |||
| column_name_id_[column_name_[i]] = i; | |||
| } | |||
| auto blob_fields = first_schema->GetBlobFields(); | |||
| for (const auto &field : blob_fields) { | |||
| blob_column_.push_back(field); | |||
| } | |||
| @@ -282,8 +288,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { | |||
| return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; | |||
| } | |||
| std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) { | |||
| std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) { | |||
| // Skip if no compress columns | |||
| *compression_size = 0; | |||
| if (!CheckCompressBlob()) return blob; | |||
| std::vector<uint8_t> dst_blob; | |||
| @@ -295,7 +302,9 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) | |||
| // Compress and return if the blob has only one column | |||
| if (num_blob_column_ == 1) { | |||
| return CompressInt(blob, int_type); | |||
| dst_blob = CompressInt(blob, int_type); | |||
| *compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size()); | |||
| return dst_blob; | |||
| } | |||
| // Just copy and continue if the column data type is not int32/int64 | |||
| @@ -319,6 +328,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) | |||
| i_src += kInt64Len + num_bytes; | |||
| } | |||
| MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; | |||
| *compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size()); | |||
| return dst_blob; | |||
| } | |||
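The new out-parameter records how many bytes each blob shrank by; the writer accumulates these deltas into compression_size_. A small Python sketch of the same contract, using zlib purely as a stand-in for the integer packing done by CompressInt:

import zlib

def compress_blob(blob):
    # Return the compressed blob plus the bytes saved. The delta can be
    # negative for incompressible input, which is why WriteShardHeader
    # clamps the accumulated total at zero before storing it.
    dst = zlib.compress(blob)
    return dst, len(blob) - len(dst)

blob, saved = compress_blob(bytes(4096))
print(len(blob), saved)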
| @@ -33,7 +33,9 @@ using mindspore::MsLogLevel::ERROR; | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| std::atomic<bool> thread_status(false); | |||
| ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); } | |||
| ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) { | |||
| index_ = std::make_shared<Index>(); | |||
| } | |||
| MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) { | |||
| shard_count_ = headers.size(); | |||
| @@ -54,6 +56,7 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l | |||
| ParseShardAddress(header["shard_addresses"]); | |||
| header_size_ = header["header_size"].get<uint64_t>(); | |||
| page_size_ = header["page_size"].get<uint64_t>(); | |||
| compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0; | |||
| } | |||
| if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { | |||
| return FAILED; | |||
| @@ -146,9 +149,12 @@ std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &fil | |||
| return {FAILED, json()}; | |||
| } | |||
| json raw_header = ret.second; | |||
| uint64_t compression_size = | |||
| raw_header.contains("compression_size") ? raw_header["compression_size"].get<uint64_t>() : 0; | |||
| json header = {{"shard_addresses", raw_header["shard_addresses"]}, | |||
| {"header_size", raw_header["header_size"]}, | |||
| {"page_size", raw_header["page_size"]}, | |||
| {"compression_size", compression_size}, | |||
| {"index_fields", raw_header["index_fields"]}, | |||
| {"blob_fields", raw_header["schema"][0]["blob_fields"]}, | |||
| {"schema", raw_header["schema"][0]["schema"]}, | |||
| @@ -343,6 +349,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() { | |||
| s += "\"index_fields\":" + index + ","; | |||
| s += "\"page\":" + pages[shardId] + ","; | |||
| s += "\"page_size\":" + std::to_string(page_size_) + ","; | |||
| s += "\"compression_size\":" + std::to_string(compression_size_) + ","; | |||
| s += "\"schema\":" + schema + ","; | |||
| s += "\"shard_addresses\":" + address + ","; | |||
| s += "\"shard_id\":" + std::to_string(shardId) + ","; | |||
| @@ -3083,20 +3083,22 @@ def _cpp_sampler_fn(sampler, dataset): | |||
| yield tuple([np.array(x, copy=False) for x in val]) | |||
| def _cpp_sampler_fn_mp(sampler, dataset, num_worker): | |||
| def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process): | |||
| """ | |||
| Multiprocessing or multithreading generator function wrapper for mappable dataset with cpp sampler. | |||
| """ | |||
| indices = sampler.get_indices() | |||
| return _sampler_fn_mp(indices, dataset, num_worker) | |||
| sample_fn = SamplerFn(dataset, num_worker, multi_process) | |||
| return sample_fn.process(indices) | |||
| def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): | |||
| def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process): | |||
| """ | |||
| Multiprocessing or multithreading generator function wrapper for mappable dataset with python sampler. | |||
| """ | |||
| indices = _fetch_py_sampler_indices(sampler, num_samples) | |||
| return _sampler_fn_mp(indices, dataset, num_worker) | |||
| sample_fn = SamplerFn(dataset, num_worker, multi_process) | |||
| return sample_fn.process(indices) | |||
| def _fetch_py_sampler_indices(sampler, num_samples): | |||
| @@ -3130,63 +3132,92 @@ def _fill_worker_indices(workers, indices, idx): | |||
| return idx | |||
| def _sampler_fn_mp(indices, dataset, num_worker): | |||
| class SamplerFn: | |||
| """ | |||
| Multiprocessing generator function wrapper master process. | |||
| Multiprocessing or multithreading generator function wrapper for the master process. | |||
| """ | |||
| workers = [] | |||
| # Event for end of epoch | |||
| eoe = multiprocessing.Event() | |||
| # Create workers | |||
| for _ in range(num_worker): | |||
| worker = _GeneratorWorker(dataset, eoe) | |||
| worker.daemon = True | |||
| workers.append(worker) | |||
| # Fill initial index queues | |||
| idx_cursor = 0 | |||
| idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) | |||
| # Start all workers | |||
| for w in workers: | |||
| w.start() | |||
| # Fetch results | |||
| for i in range(len(indices)): | |||
| # Fetch result and put index | |||
| try: | |||
| result = workers[i % num_worker].get() | |||
| except queue.Empty: | |||
| raise Exception("Generator worker process timeout") | |||
| except KeyboardInterrupt: | |||
| for w in workers: | |||
| w.terminate() | |||
| def __init__(self, dataset, num_worker, multi_process): | |||
| self.workers = [] | |||
| self.num_worker = num_worker | |||
| self.multi_process = multi_process | |||
| # Event for end of epoch | |||
| if multi_process: | |||
| self.eoe = multiprocessing.Event() | |||
| self.eof = multiprocessing.Event() | |||
| else: | |||
| self.eoe = threading.Event() | |||
| self.eof = threading.Event() | |||
| # Create workers | |||
| for _ in range(num_worker): | |||
| if multi_process: | |||
| worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof) | |||
| else: | |||
| worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) | |||
| worker.daemon = True | |||
| self.workers.append(worker) | |||
| def process(self, indices): | |||
| """ | |||
| The main process: start the child processes or threads, fill the index queues, | |||
| then fetch results from the result queues and yield them. | |||
| """ | |||
| # Fill initial index queues | |||
| idx_cursor = 0 | |||
| idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) | |||
| # Start all workers | |||
| for w in self.workers: | |||
| w.start() | |||
| # Fetch results | |||
| for i in range(len(indices)): | |||
| # Fetch result and put index | |||
| try: | |||
| result = self.workers[i % self.num_worker].get() | |||
| except queue.Empty: | |||
| raise Exception("Generator worker process timeout") | |||
| except KeyboardInterrupt: | |||
| self.eof.set() | |||
| for w in self.workers: | |||
| w.terminate() | |||
| w.join() | |||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||
| if idx_cursor < len(indices): | |||
| idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) | |||
| # Set eoe event once all indices are sent | |||
| if idx_cursor == len(indices) and not self.eoe.is_set(): | |||
| self.eoe.set() | |||
| yield tuple([np.array(x, copy=False) for x in result]) | |||
| def __del__(self): | |||
| self.eoe.set() | |||
| self.eof.set() | |||
| if not self.multi_process: | |||
| for w in self.workers: | |||
| w.join() | |||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||
| if idx_cursor < len(indices): | |||
| idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) | |||
| # Set eoe event once all indices are sent | |||
| if idx_cursor == len(indices) and not eoe.is_set(): | |||
| eoe.set() | |||
| yield tuple([np.array(x, copy=False) for x in result]) | |||
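The refactor replaces the free `_sampler_fn_mp` function with a `SamplerFn` class so the same master loop can drive either processes or threads. A reduced sketch of that dual-backend construction (class name hypothetical; the fetch loop is omitted):

import multiprocessing
import threading

class MiniSamplerFn:
    # Reduced illustration of SamplerFn.__init__: pick the process or
    # thread primitives once, then the rest of the logic is shared.
    def __init__(self, num_worker, multi_process):
        mod = multiprocessing if multi_process else threading
        self.eoe = mod.Event()  # end of epoch
        self.eof = mod.Event()  # shutdown signal
        self.workers = []       # _GeneratorWorkerMp or _GeneratorWorkerMt

print(type(MiniSamplerFn(4, False).eoe))  # thread-based event
print(type(MiniSamplerFn(4, True).eoe))   # process-shared event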
| def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): | |||
| def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): | |||
| """ | |||
| Multiprocessing generator worker process loop. | |||
| Multiprocessing or multithreading generator worker loop. | |||
| """ | |||
| while True: | |||
| # Fetch index, block | |||
| try: | |||
| idx = idx_queue.get() | |||
| idx = idx_queue.get(timeout=10) | |||
| except KeyboardInterrupt: | |||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||
| except queue.Empty: | |||
| if eof.is_set() or eoe.is_set(): | |||
| raise Exception("Generator worker receives queue.Empty") | |||
| continue | |||
| if idx is None: | |||
| # When the index queue goes out of scope in the master process, a None item can be fetched from it. | |||
| # Upon receiving None, the worker should verify that EOE has been set. | |||
| assert eoe.is_set(), "EOE not set when None was fetched from the index queue" | |||
| return | |||
| if eof.is_set(): | |||
| return | |||
| # Fetch data; any exception from __getitem__ terminates the worker and times out the master process | |||
| result = dataset[idx] | |||
| # Send data, block | |||
| @@ -3195,17 +3226,42 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): | |||
| except KeyboardInterrupt: | |||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||
| del result, idx | |||
| if eoe.is_set() and idx_queue.empty(): | |||
| return | |||
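The worker loop now polls its index queue with a 10-second timeout instead of blocking indefinitely, so it can notice the new eof event and exit cleanly. The shutdown pattern in isolation, as a hedged sketch:

import queue
import threading

def worker_loop(idx_queue, eof):
    # Poll with a timeout so the worker can observe eof instead of
    # blocking forever on an empty queue (mirrors _generator_worker_loop).
    while True:
        try:
            idx = idx_queue.get(timeout=1)
        except queue.Empty:
            if eof.is_set():
                return
            continue
        if idx is None or eof.is_set():
            return
        # ... fetch dataset[idx] and push it to the result queue here ...

q, eof = queue.Queue(), threading.Event()
t = threading.Thread(target=worker_loop, args=(q, eof))
t.start()
eof.set()
t.join()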
| class _GeneratorWorker(multiprocessing.Process): | |||
| class _GeneratorWorkerMt(threading.Thread): | |||
| """ | |||
| Worker thread for the multithreading Generator. | |||
| """ | |||
| def __init__(self, dataset, eoe, eof): | |||
| self.idx_queue = queue.Queue(16) | |||
| self.res_queue = queue.Queue(16) | |||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||
| def put(self, item): | |||
| """ | |||
| Put function for worker index queue. Never block. Raise queue.Full on failure. | |||
| """ | |||
| self.idx_queue.put_nowait(item) | |||
| def get(self): | |||
| """ | |||
| Get function for worker result queue. Block with timeout. | |||
| """ | |||
| return self.res_queue.get(timeout=10) | |||
| class _GeneratorWorkerMp(multiprocessing.Process): | |||
| """ | |||
| Worker process for the multiprocessing Generator. | |||
| """ | |||
| def __init__(self, dataset, eoe): | |||
| def __init__(self, dataset, eoe, eof): | |||
| self.idx_queue = multiprocessing.Queue(16) | |||
| self.res_queue = multiprocessing.Queue(16) | |||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe)) | |||
| super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) | |||
| def put(self, item): | |||
| """ | |||
| @@ -3217,7 +3273,7 @@ class _GeneratorWorker(multiprocessing.Process): | |||
| """ | |||
| Get function for worker result queue. Block with timeout. | |||
| """ | |||
| return self.res_queue.get() | |||
| return self.res_queue.get(timeout=10) | |||
| def __del__(self): | |||
| self.terminate() | |||
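Both worker variants expose the same put/get interface over queues bounded at 16 entries, which gives natural backpressure between the master and the workers. A quick demonstration of that bound:

import queue

q = queue.Queue(16)
for i in range(16):
    q.put_nowait(i)
try:
    q.put_nowait(16)        # the 17th item overflows the bound
except queue.Full:
    print("backpressure engages after", q.qsize(), "items")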
| @@ -3280,6 +3336,8 @@ class GeneratorDataset(MappableDataset): | |||
| When this argument is specified, 'num_samples' will not take effect. Random accessible input is required. | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only | |||
| when num_shards is also specified. Random accessible input is required. | |||
| python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This | |||
| option can be beneficial if the Python operation is computationally heavy (default=True). | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -3316,12 +3374,14 @@ class GeneratorDataset(MappableDataset): | |||
| @check_generatordataset | |||
| def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, | |||
| num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None): | |||
| num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, | |||
| python_multiprocessing=True): | |||
| super().__init__(num_parallel_workers) | |||
| self.source = source | |||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | |||
| self.num_samples = num_samples | |||
| self.num_shards = num_shards | |||
| self.python_multiprocessing = python_multiprocessing | |||
| if column_names is not None and not isinstance(column_names, list): | |||
| column_names = [column_names] | |||
| @@ -3403,12 +3463,16 @@ class GeneratorDataset(MappableDataset): | |||
| sampler_instance.set_num_rows(len(self.source)) | |||
| sampler_instance.initialize() | |||
| if new_op.num_parallel_workers > 1: | |||
| new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers)) | |||
| new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, | |||
| new_op.num_parallel_workers, | |||
| self.python_multiprocessing)) | |||
| else: | |||
| new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) | |||
| else: | |||
| if new_op.num_parallel_workers > 1: | |||
| new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers)) | |||
| new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, | |||
| new_op.num_parallel_workers, | |||
| self.python_multiprocessing)) | |||
| else: | |||
| new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) | |||
| else: | |||
| @@ -16,8 +16,11 @@ | |||
| graphdata.py supports loading graph dataset for GNN network training, | |||
| and provides operations related to graph data. | |||
| """ | |||
| import atexit | |||
| import time | |||
| import numpy as np | |||
| from mindspore._c_dataengine import Graph | |||
| from mindspore._c_dataengine import GraphDataClient | |||
| from mindspore._c_dataengine import GraphDataServer | |||
| from mindspore._c_dataengine import Tensor | |||
| from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ | |||
| @@ -34,14 +37,52 @@ class GraphData: | |||
| dataset_file (str): One of file names in dataset. | |||
| num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel | |||
| (default=None). | |||
| working_mode (str, optional): Set the working mode; 'local', 'client' and 'server' are supported (default='local'). | |||
| - 'local', used in non-distributed training scenarios. | |||
| - 'client', used in distributed training scenarios, the client does not load data, | |||
| but obtains data from the server. | |||
| - 'server', used in distributed training scenarios, the server loads the data | |||
| and serves it to clients. | |||
| hostname (str, optional): Valid when working_mode is set to 'client' or 'server', | |||
| set the hostname of the graph data server (default='127.0.0.1'). | |||
| port (int, optional): Valid when working_mode is set to 'client' or 'server', | |||
| set the port of the graph data server, the range is 1024-65535 (default=50051). | |||
| num_client (int, optional): Valid when working_mode is set to 'server', | |||
| set the number of clients expected to connect, and the server will allocate corresponding | |||
| resources according to this parameter (default=1). | |||
| auto_shutdown (bool, optional): Valid when working_mode is set to 'server'. If True, the server | |||
| exits automatically once all expected clients have connected | |||
| and then all of them have disconnected (default=True). | |||
| """ | |||
| @check_gnn_graphdata | |||
| def __init__(self, dataset_file, num_parallel_workers=None): | |||
| def __init__(self, dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, | |||
| num_client=1, auto_shutdown=True): | |||
| self._dataset_file = dataset_file | |||
| self._working_mode = working_mode | |||
| if num_parallel_workers is None: | |||
| num_parallel_workers = 1 | |||
| self._graph = Graph(dataset_file, num_parallel_workers) | |||
| def stop(): | |||
| self._graph_data.stop() | |||
| atexit.register(stop) | |||
| if working_mode in ['local', 'client']: | |||
| self._graph_data = GraphDataClient(dataset_file, num_parallel_workers, working_mode, hostname, port) | |||
| if working_mode == 'server': | |||
| self._graph_data = GraphDataServer( | |||
| dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown) | |||
| try: | |||
| while not self._graph_data.is_stoped(): | |||
| time.sleep(1) | |||
| except KeyboardInterrupt: | |||
| # self._graph_data.stop() | |||
| raise Exception("Graph data server receives KeyboardInterrupt") | |||
| @check_gnn_get_all_nodes | |||
| def get_all_nodes(self, node_type): | |||
| @@ -62,7 +103,9 @@ class GraphData: | |||
| Raises: | |||
| TypeError: If `node_type` is not integer. | |||
| """ | |||
| return self._graph.get_all_nodes(node_type).as_array() | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| return self._graph_data.get_all_nodes(node_type).as_array() | |||
| @check_gnn_get_all_edges | |||
| def get_all_edges(self, edge_type): | |||
| @@ -83,7 +126,9 @@ class GraphData: | |||
| Raises: | |||
| TypeError: If `edge_type` is not integer. | |||
| """ | |||
| return self._graph.get_all_edges(edge_type).as_array() | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| return self._graph_data.get_all_edges(edge_type).as_array() | |||
| @check_gnn_get_nodes_from_edges | |||
| def get_nodes_from_edges(self, edge_list): | |||
| @@ -99,7 +144,9 @@ class GraphData: | |||
| Raises: | |||
| TypeError: If `edge_list` is not list or ndarray. | |||
| """ | |||
| return self._graph.get_nodes_from_edges(edge_list).as_array() | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| return self._graph_data.get_nodes_from_edges(edge_list).as_array() | |||
| @check_gnn_get_all_neighbors | |||
| def get_all_neighbors(self, node_list, neighbor_type): | |||
| @@ -123,7 +170,9 @@ class GraphData: | |||
| TypeError: If `node_list` is not list or ndarray. | |||
| TypeError: If `neighbor_type` is not integer. | |||
| """ | |||
| return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array() | |||
| @check_gnn_get_sampled_neighbors | |||
| def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): | |||
| @@ -155,7 +204,9 @@ class GraphData: | |||
| TypeError: If `neighbor_nums` is not list or ndarray. | |||
| TypeError: If `neighbor_types` is not list or ndarray. | |||
| """ | |||
| return self._graph.get_sampled_neighbors( | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| return self._graph_data.get_sampled_neighbors( | |||
| node_list, neighbor_nums, neighbor_types).as_array() | |||
| @check_gnn_get_neg_sampled_neighbors | |||
| @@ -182,7 +233,9 @@ class GraphData: | |||
| TypeError: If `neg_neighbor_num` is not integer. | |||
| TypeError: If `neg_neighbor_type` is not integer. | |||
| """ | |||
| return self._graph.get_neg_sampled_neighbors( | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| return self._graph_data.get_neg_sampled_neighbors( | |||
| node_list, neg_neighbor_num, neg_neighbor_type).as_array() | |||
| @check_gnn_get_node_feature | |||
| @@ -207,10 +260,12 @@ class GraphData: | |||
| TypeError: If `node_list` is not list or ndarray. | |||
| TypeError: If `feature_types` is not list or ndarray. | |||
| """ | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| if isinstance(node_list, list): | |||
| node_list = np.array(node_list, dtype=np.int32) | |||
| return [ | |||
| t.as_array() for t in self._graph.get_node_feature( | |||
| t.as_array() for t in self._graph_data.get_node_feature( | |||
| Tensor(node_list), | |||
| feature_types)] | |||
| @@ -236,10 +291,12 @@ class GraphData: | |||
| TypeError: If `edge_list` is not list or ndarray. | |||
| TypeError: If `feature_types` is not list or ndarray. | |||
| """ | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| if isinstance(edge_list, list): | |||
| edge_list = np.array(edge_list, dtype=np.int32) | |||
| return [ | |||
| t.as_array() for t in self._graph.get_edge_feature( | |||
| t.as_array() for t in self._graph_data.get_edge_feature( | |||
| Tensor(edge_list), | |||
| feature_types)] | |||
| @@ -252,7 +309,9 @@ class GraphData: | |||
| dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num, | |||
| node_feature_type and edge_feature_type. | |||
| """ | |||
| return self._graph.graph_info() | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| return self._graph_data.graph_info() | |||
| @check_gnn_random_walk | |||
| def random_walk( | |||
| @@ -285,5 +344,7 @@ class GraphData: | |||
| TypeError: If `target_nodes` is not list or ndarray. | |||
| TypeError: If `meta_path` is not list or ndarray. | |||
| """ | |||
| return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param, | |||
| default_node).as_array() | |||
| if self._working_mode == 'server': | |||
| raise Exception("This method is not supported when working mode is server") | |||
| return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param, | |||
| default_node).as_array() | |||
| @@ -18,6 +18,7 @@ Built-in validators. | |||
| """ | |||
| import inspect as ins | |||
| import os | |||
| import re | |||
| from functools import wraps | |||
| import numpy as np | |||
| @@ -912,16 +913,36 @@ def check_split(method): | |||
| return new_method | |||
| def check_hostname(hostname): | |||
| if not hostname or len(hostname) > 255: | |||
| return False | |||
| if hostname[-1] == ".": | |||
| hostname = hostname[:-1] # strip exactly one dot from the right, if present | |||
| allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE) | |||
| return all(allowed.match(x) for x in hostname.split(".")) | |||
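The validator accepts standard hostnames: dot-separated labels of 1-63 alphanumeric or hyphen characters that neither start nor end with a hyphen, at most 255 characters overall, with one trailing dot tolerated. For instance:

print(check_hostname("graph-server-01.example.com"))  # True
print(check_hostname("example.com."))                 # True, trailing dot stripped
print(check_hostname("-bad-label.example.com"))       # False, leading hyphen
print(check_hostname("a" * 256))                      # False, too long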
| def check_gnn_graphdata(method): | |||
| """check the input arguments of graphdata.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [dataset_file, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs) | |||
| [dataset_file, num_parallel_workers, working_mode, hostname, | |||
| port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs) | |||
| check_file(dataset_file) | |||
| if num_parallel_workers is not None: | |||
| check_num_parallel_workers(num_parallel_workers) | |||
| type_check(hostname, (str,), "hostname") | |||
| if check_hostname(hostname) is False: | |||
| raise ValueError("The hostname is illegal") | |||
| type_check(working_mode, (str,), "working_mode") | |||
| if working_mode not in {'local', 'client', 'server'}: | |||
| raise ValueError("Invalid working mode") | |||
| type_check(port, (int,), "port") | |||
| check_value(port, (1024, 65535), "port") | |||
| type_check(num_client, (int,), "num_client") | |||
| check_value(num_client, (1, 255), "num_client") | |||
| type_check(auto_shutdown, (bool,), "auto_shutdown") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -15,6 +15,7 @@ | |||
| """ | |||
| User-defined API for MindRecord GNN writer. | |||
| """ | |||
| import numpy as np | |||
| social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335], | |||
| [348, 336], [348, 337], [348, 338], [348, 340], [348, 341], | |||
| [348, 342], [348, 343], [348, 344], [348, 345], [348, 346], | |||
| @@ -29,7 +30,7 @@ social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335], | |||
| [355, 352], [353, 350], [352, 349], [351, 349], [350, 349]] | |||
| # profile: (num_features, feature_data_types, feature_shapes) | |||
| node_profile = (0, [], []) | |||
| node_profile = (2, ["int64", "int32"], [[-1], [-1]]) | |||
| edge_profile = (0, [], []) | |||
| @@ -51,7 +52,9 @@ def yield_nodes(task_id=0): | |||
| node_list.sort() | |||
| print(node_list) | |||
| for node_id in node_list: | |||
| node = {'id': node_id, 'type': 1} | |||
| node = {'id': node_id, 'type': 1, | |||
| 'feature_1': np.ones((5,), dtype=np.int64), | |||
| 'feature_2': np.ones((10,), dtype=np.int32)} | |||
| yield node | |||
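For clarity, a standalone version of the updated generator: each node now carries the two features promised by node_profile above (function name illustrative).

import numpy as np

def yield_example_nodes(node_ids):
    # Matches node_profile = (2, ["int64", "int32"], [[-1], [-1]]):
    # every node carries one int64 and one int32 feature vector.
    for node_id in node_ids:
        yield {'id': node_id, 'type': 1,
               'feature_1': np.ones((5,), dtype=np.int64),
               'feature_2': np.ones((10,), dtype=np.int32)}

for node in yield_example_nodes([348, 349]):
    print(node['id'], node['feature_1'].dtype, node['feature_2'].shape)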
| @@ -22,6 +22,7 @@ | |||
| #include "gtest/gtest.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/engine/gnn/node.h" | |||
| #include "minddata/dataset/engine/gnn/graph_data_impl.h" | |||
| #include "minddata/dataset/engine/gnn/graph_loader.h" | |||
| using namespace mindspore::dataset; | |||
| @@ -39,30 +40,9 @@ class MindDataTestGNNGraph : public UT::Common { | |||
| MindDataTestGNNGraph() = default; | |||
| }; | |||
| TEST_F(MindDataTestGNNGraph, TestGraphLoader) { | |||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||
| GraphLoader gl(path, 4); | |||
| EXPECT_TRUE(gl.InitAndLoad().IsOk()); | |||
| NodeIdMap n_id_map; | |||
| EdgeIdMap e_id_map; | |||
| NodeTypeMap n_type_map; | |||
| EdgeTypeMap e_type_map; | |||
| NodeFeatureMap n_feature_map; | |||
| EdgeFeatureMap e_feature_map; | |||
| DefaultNodeFeatureMap default_node_feature_map; | |||
| DefaultEdgeFeatureMap default_edge_feature_map; | |||
| EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map, | |||
| &default_node_feature_map, &default_edge_feature_map) | |||
| .IsOk()); | |||
| EXPECT_EQ(n_id_map.size(), 20); | |||
| EXPECT_EQ(e_id_map.size(), 40); | |||
| EXPECT_EQ(n_type_map[2].size(), 10); | |||
| EXPECT_EQ(n_type_map[1].size(), 10); | |||
| } | |||
| TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||
| Graph graph(path, 1); | |||
| GraphDataImpl graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| @@ -103,7 +83,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { | |||
| TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | |||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||
| Graph graph(path, 1); | |||
| GraphDataImpl graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| @@ -194,7 +174,7 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { | |||
| TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { | |||
| std::string path = "data/mindrecord/testGraphData/testdata"; | |||
| Graph graph(path, 1); | |||
| GraphDataImpl graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| @@ -237,7 +217,7 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { | |||
| TEST_F(MindDataTestGNNGraph, TestRandomWalk) { | |||
| std::string path = "data/mindrecord/testGraphData/sns"; | |||
| Graph graph(path, 1); | |||
| GraphDataImpl graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| @@ -263,7 +243,7 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) { | |||
| TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) { | |||
| std::string path = "data/mindrecord/testGraphData/sns"; | |||
| Graph graph(path, 1); | |||
| GraphDataImpl graph(path, 1); | |||
| Status s = graph.Init(); | |||
| EXPECT_TRUE(s.IsOk()); | |||
| @@ -0,0 +1,125 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import random | |||
| import time | |||
| from multiprocessing import Process | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| DATASET_FILE = "../data/mindrecord/testGraphData/testdata" | |||
| def graphdata_startserver(): | |||
| """ | |||
| start graphdata server | |||
| """ | |||
| logger.info('test start server.\n') | |||
| ds.GraphData(DATASET_FILE, 1, 'server') | |||
| class RandomBatchedSampler(ds.Sampler): | |||
| # RandomBatchedSampler generates a random sequence without replacement in a batched manner | |||
| def __init__(self, index_range, num_edges_per_sample): | |||
| super().__init__() | |||
| self.index_range = index_range | |||
| self.num_edges_per_sample = num_edges_per_sample | |||
| def __iter__(self): | |||
| indices = [i+1 for i in range(self.index_range)] | |||
| # Reset random seed here if necessary | |||
| # random.seed(0) | |||
| random.shuffle(indices) | |||
| for i in range(0, self.index_range, self.num_edges_per_sample): | |||
| # Drop remainder | |||
| if i + self.num_edges_per_sample <= self.index_range: | |||
| yield indices[i: i + self.num_edges_per_sample] | |||
| class GNNGraphDataset(): | |||
| def __init__(self, g, batch_num): | |||
| self.g = g | |||
| self.batch_num = batch_num | |||
| def __len__(self): | |||
| # Total sample size of GNN dataset | |||
| # In this case, the size should be total_num_edges/num_edges_per_sample | |||
| return self.g.graph_info()['edge_num'][0] // self.batch_num | |||
| def __getitem__(self, index): | |||
| # index will be a list of indices yielded from RandomBatchedSampler | |||
| # Fetch edges/nodes/samples/features based on indices | |||
| nodes = self.g.get_nodes_from_edges(index.astype(np.int32)) | |||
| nodes = nodes[:, 0] | |||
| neg_nodes = self.g.get_neg_sampled_neighbors( | |||
| node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1) | |||
| nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[ | |||
| 2, 2], neighbor_types=[2, 1]) | |||
| neg_nodes_neighbors = self.g.get_sampled_neighbors( | |||
| node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2]) | |||
| nodes_neighbors_features = self.g.get_node_feature( | |||
| node_list=nodes_neighbors, feature_types=[2, 3]) | |||
| neg_neighbors_features = self.g.get_node_feature( | |||
| node_list=neg_nodes_neighbors, feature_types=[2, 3]) | |||
| return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1] | |||
| def test_graphdata_distributed(): | |||
| """ | |||
| Test distributed | |||
| """ | |||
| logger.info('test distributed.\n') | |||
| p1 = Process(target=graphdata_startserver) | |||
| p1.start() | |||
| time.sleep(2) | |||
| g = ds.GraphData(DATASET_FILE, 1, 'client') | |||
| nodes = g.get_all_nodes(1) | |||
| assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] | |||
| row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3]) | |||
| assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0], | |||
| [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1], | |||
| [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]] | |||
| assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4] | |||
| edges = g.get_all_edges(0) | |||
| assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, | |||
| 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] | |||
| features = g.get_edge_feature(edges, [1, 2]) | |||
| assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, | |||
| 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0] | |||
| batch_num = 2 | |||
| edge_num = g.graph_info()['edge_num'][0] | |||
| out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"] | |||
| dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names, | |||
| sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4, | |||
| python_multiprocessing=False) | |||
| dataset = dataset.repeat(2) | |||
| itr = dataset.create_dict_iterator() | |||
| i = 0 | |||
| for data in itr: | |||
| assert data['neighbors'].shape == (2, 7) | |||
| assert data['neg_neighbors'].shape == (6, 7) | |||
| assert data['neighbors_features'].shape == (2, 7) | |||
| assert data['neg_neighbors_features'].shape == (6, 7) | |||
| i += 1 | |||
| assert i == 40 | |||
| if __name__ == '__main__': | |||
| test_graphdata_distributed() | |||