You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_loader.cc 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <future>
  17. #include <tuple>
  18. #include <utility>
  19. #include "dataset/engine/gnn/graph_loader.h"
  20. #include "mindspore/ccsrc/mindrecord/include/shard_error.h"
  21. #include "dataset/engine/gnn/local_edge.h"
  22. #include "dataset/engine/gnn/local_node.h"
  23. #include "dataset/util/task_manager.h"
  24. using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
  25. namespace mindspore {
  26. namespace dataset {
  27. namespace gnn {
  28. using mindrecord::MSRStatus;
  29. GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
  30. : mr_path_(mr_filepath),
  31. num_workers_(num_workers),
  32. row_id_(0),
  33. shard_reader_(nullptr),
  34. keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
  35. Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
  36. EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map,
  37. EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) {
  38. for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
  39. while (dq.empty() == false) {
  40. std::shared_ptr<Node> node_ptr = dq.front();
  41. n_id_map->insert({node_ptr->id(), node_ptr});
  42. (*n_type_map)[node_ptr->type()].push_back(node_ptr->id());
  43. dq.pop_front();
  44. }
  45. }
  46. for (std::deque<std::shared_ptr<Edge>> &dq : e_deques_) {
  47. while (dq.empty() == false) {
  48. std::shared_ptr<Edge> edge_ptr = dq.front();
  49. std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p;
  50. RETURN_IF_NOT_OK(edge_ptr->GetNode(&p));
  51. auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id());
  52. CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first));
  53. CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first));
  54. RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
  55. RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
  56. e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
  57. (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id());
  58. dq.pop_front();
  59. }
  60. }
  61. for (auto &itr : *n_type_map) itr.second.shrink_to_fit();
  62. for (auto &itr : *e_type_map) itr.second.shrink_to_fit();
  63. MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map);
  64. return Status::OK();
  65. }
  66. Status GraphLoader::InitAndLoad() {
  67. CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n");
  68. CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n");
  69. n_deques_.resize(num_workers_);
  70. e_deques_.resize(num_workers_);
  71. n_feature_maps_.resize(num_workers_);
  72. e_feature_maps_.resize(num_workers_);
  73. default_feature_maps_.resize(num_workers_);
  74. TaskGroup vg;
  75. shard_reader_ = std::make_unique<ShardReader>();
  76. CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
  77. "Fail to open" + mr_path_);
  78. CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
  79. CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
  80. mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"];
  81. for (const std::string &key : keys_) {
  82. if (schema.find(key) == schema.end()) {
  83. RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
  84. }
  85. }
  86. // launching worker threads
  87. for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
  88. RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
  89. }
  90. // wait for threads to finish and check its return code
  91. vg.join_all(Task::WaitFlag::kBlocking);
  92. RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny());
  93. return Status::OK();
  94. }
  95. Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
  96. std::shared_ptr<Node> *node, NodeFeatureMap *feature_map,
  97. DefaultFeatureMap *default_feature) {
  98. NodeIdType node_id = col_jsn["first_id"];
  99. NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
  100. (*node) = std::make_shared<LocalNode>(node_id, node_type);
  101. std::vector<int32_t> indices;
  102. RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices));
  103. for (int32_t ind : indices) {
  104. std::shared_ptr<Tensor> tensor;
  105. RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
  106. RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
  107. (*feature_map)[node_type].insert(ind);
  108. if ((*default_feature)[ind] == nullptr) {
  109. std::shared_ptr<Tensor> zero_tensor;
  110. RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
  111. RETURN_IF_NOT_OK(zero_tensor->Zero());
  112. (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
  113. }
  114. }
  115. return Status::OK();
  116. }
  117. Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
  118. std::shared_ptr<Edge> *edge, EdgeFeatureMap *feature_map,
  119. DefaultFeatureMap *default_feature) {
  120. EdgeIdType edge_id = col_jsn["first_id"];
  121. EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]);
  122. NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"];
  123. std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1);
  124. std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
  125. (*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
  126. std::vector<int32_t> indices;
  127. RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices));
  128. for (int32_t ind : indices) {
  129. std::shared_ptr<Tensor> tensor;
  130. RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
  131. RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
  132. (*feature_map)[edge_type].insert(ind);
  133. if ((*default_feature)[ind] == nullptr) {
  134. std::shared_ptr<Tensor> zero_tensor;
  135. RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
  136. RETURN_IF_NOT_OK(zero_tensor->Zero());
  137. (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
  138. }
  139. }
  140. return Status::OK();
  141. }
  142. Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
  143. const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) {
  144. const unsigned char *data = nullptr;
  145. std::unique_ptr<unsigned char[]> data_ptr;
  146. uint64_t n_bytes = 0, col_type_size = 1;
  147. mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  148. std::vector<int64_t> column_shape;
  149. MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
  150. key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
  151. CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
  152. if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
  153. RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible,
  154. std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
  155. std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data));
  156. return Status::OK();
  157. }
  158. Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
  159. const mindrecord::json &col_jsn, std::vector<int32_t> *indices) {
  160. const unsigned char *data = nullptr;
  161. std::unique_ptr<unsigned char[]> data_ptr;
  162. uint64_t n_bytes = 0, col_type_size = 1;
  163. mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  164. std::vector<int64_t> column_shape;
  165. MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
  166. key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
  167. CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  168. if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
  169. for (int i = 0; i < n_bytes; i += col_type_size) {
  170. int32_t feature_ind = -1;
  171. if (col_type == mindrecord::ColumnInt32) {
  172. feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
  173. } else if (col_type == mindrecord::ColumnInt64) {
  174. feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
  175. } else {
  176. RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
  177. }
  178. if (feature_ind >= 0) indices->push_back(feature_ind);
  179. }
  180. return Status::OK();
  181. }
// Worker-thread body: repeatedly claims the next row id from the shared
// counter, fetches that row from the shard reader, and materializes it as a
// node or edge in this worker's private deques/feature maps (indexed by
// worker_id, so no locking is needed on those containers).
// @param worker_id - index of this worker's private output slots
// @return Status - error from LoadNode/LoadEdge, or OK when rows run out.
Status GraphLoader::WorkerEntry(int32_t worker_id) {
  // Handshake
  TaskManager::FindMe()->Post();
  // NOTE(review): row_id_++ runs concurrently in every worker; this is only
  // safe if row_id_ is declared atomic in the header — confirm.
  auto ret = shard_reader_->GetNextById(row_id_++, worker_id);
  ShardTuple rows = ret.second;
  // An empty ShardTuple signals that the reader has no row for that id.
  while (rows.empty() == false) {
    RETURN_IF_INTERRUPTED();
    for (const auto &tupled_row : rows) {
      std::vector<uint8_t> col_blob = std::get<0>(tupled_row);
      mindrecord::json col_jsn = std::get<1>(tupled_row);
      // "attribute" marks the row as a node ("n") or an edge ("e").
      std::string attr = col_jsn["attribute"];
      if (attr == "n") {
        std::shared_ptr<Node> node_ptr;
        RETURN_IF_NOT_OK(
          LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
        n_deques_[worker_id].emplace_back(node_ptr);
      } else if (attr == "e") {
        std::shared_ptr<Edge> edge_ptr;
        RETURN_IF_NOT_OK(
          LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
        e_deques_[worker_id].emplace_back(edge_ptr);
      } else {
        // Unknown rows are skipped with a warning rather than failing the load.
        MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node.";
      }
    }
    auto rc = shard_reader_->GetNextById(row_id_++, worker_id);
    rows = rc.second;
  }
  return Status::OK();
}
  212. void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map,
  213. DefaultFeatureMap *default_feature_map) {
  214. for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
  215. for (auto &m : n_feature_maps_[wkr_id]) {
  216. for (auto &n : m.second) (*n_feature_map)[m.first].insert(n);
  217. }
  218. for (auto &m : e_feature_maps_[wkr_id]) {
  219. for (auto &n : m.second) (*e_feature_map)[m.first].insert(n);
  220. }
  221. for (auto &m : default_feature_maps_[wkr_id]) {
  222. (*default_feature_map)[m.first] = m.second;
  223. }
  224. }
  225. n_feature_maps_.clear();
  226. e_feature_maps_.clear();
  227. }
  228. } // namespace gnn
  229. } // namespace dataset
  230. } // namespace mindspore