You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_loader.cc 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <future>
  17. #include <tuple>
  18. #include <utility>
  19. #include "dataset/engine/gnn/graph_loader.h"
  20. #include "mindspore/ccsrc/mindrecord/include/shard_error.h"
  21. #include "dataset/engine/gnn/local_edge.h"
  22. #include "dataset/engine/gnn/local_node.h"
  23. #include "dataset/util/task_manager.h"
  24. using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
  25. namespace mindspore {
  26. namespace dataset {
  27. namespace gnn {
  28. using mindrecord::MSRStatus;
  29. GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
  30. : mr_path_(mr_filepath),
  31. num_workers_(num_workers),
  32. row_id_(0),
  33. keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
  34. Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
  35. EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map,
  36. EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) {
  37. for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
  38. while (dq.empty() == false) {
  39. std::shared_ptr<Node> node_ptr = dq.front();
  40. n_id_map->insert({node_ptr->id(), node_ptr});
  41. (*n_type_map)[node_ptr->type()].push_back(node_ptr->id());
  42. dq.pop_front();
  43. }
  44. }
  45. for (std::deque<std::shared_ptr<Edge>> &dq : e_deques_) {
  46. while (dq.empty() == false) {
  47. std::shared_ptr<Edge> edge_ptr = dq.front();
  48. std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p;
  49. RETURN_IF_NOT_OK(edge_ptr->GetNode(&p));
  50. auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id());
  51. CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first));
  52. CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first));
  53. RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
  54. RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
  55. e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
  56. (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id());
  57. dq.pop_front();
  58. }
  59. }
  60. for (auto &itr : *n_type_map) itr.second.shrink_to_fit();
  61. for (auto &itr : *e_type_map) itr.second.shrink_to_fit();
  62. MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map);
  63. return Status::OK();
  64. }
  65. Status GraphLoader::InitAndLoad() {
  66. CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n");
  67. CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n");
  68. n_deques_.resize(num_workers_);
  69. e_deques_.resize(num_workers_);
  70. n_feature_maps_.resize(num_workers_);
  71. e_feature_maps_.resize(num_workers_);
  72. default_feature_maps_.resize(num_workers_);
  73. TaskGroup vg;
  74. shard_reader_ = std::make_unique<ShardReader>();
  75. CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
  76. "Fail to open" + mr_path_);
  77. CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
  78. CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
  79. mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"];
  80. for (const std::string &key : keys_) {
  81. if (schema.find(key) == schema.end()) {
  82. RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
  83. }
  84. }
  85. // launching worker threads
  86. for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
  87. RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
  88. }
  89. // wait for threads to finish and check its return code
  90. vg.join_all(Task::WaitFlag::kBlocking);
  91. RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny());
  92. return Status::OK();
  93. }
  94. Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
  95. std::shared_ptr<Node> *node, NodeFeatureMap *feature_map,
  96. DefaultFeatureMap *default_feature) {
  97. NodeIdType node_id = col_jsn["first_id"];
  98. NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
  99. (*node) = std::make_shared<LocalNode>(node_id, node_type);
  100. std::vector<int32_t> indices;
  101. RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices));
  102. for (int32_t ind : indices) {
  103. std::shared_ptr<Tensor> tensor;
  104. RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
  105. RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
  106. (*feature_map)[node_type].insert(ind);
  107. if ((*default_feature)[ind] == nullptr) {
  108. std::shared_ptr<Tensor> zero_tensor;
  109. RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
  110. RETURN_IF_NOT_OK(zero_tensor->Zero());
  111. (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
  112. }
  113. }
  114. return Status::OK();
  115. }
  116. Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
  117. std::shared_ptr<Edge> *edge, EdgeFeatureMap *feature_map,
  118. DefaultFeatureMap *default_feature) {
  119. EdgeIdType edge_id = col_jsn["first_id"];
  120. EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]);
  121. NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"];
  122. std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1);
  123. std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
  124. (*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
  125. std::vector<int32_t> indices;
  126. RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices));
  127. for (int32_t ind : indices) {
  128. std::shared_ptr<Tensor> tensor;
  129. RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
  130. RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
  131. (*feature_map)[edge_type].insert(ind);
  132. if ((*default_feature)[ind] == nullptr) {
  133. std::shared_ptr<Tensor> zero_tensor;
  134. RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
  135. RETURN_IF_NOT_OK(zero_tensor->Zero());
  136. (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
  137. }
  138. }
  139. return Status::OK();
  140. }
  141. Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
  142. const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) {
  143. const unsigned char *data = nullptr;
  144. std::unique_ptr<unsigned char[]> data_ptr;
  145. uint64_t n_bytes = 0, col_type_size = 1;
  146. mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  147. std::vector<int64_t> column_shape;
  148. MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
  149. key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
  150. CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
  151. if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
  152. RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible,
  153. std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
  154. std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data));
  155. return Status::OK();
  156. }
  157. Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
  158. const mindrecord::json &col_jsn, std::vector<int32_t> *indices) {
  159. const unsigned char *data = nullptr;
  160. std::unique_ptr<unsigned char[]> data_ptr;
  161. uint64_t n_bytes = 0, col_type_size = 1;
  162. mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  163. std::vector<int64_t> column_shape;
  164. MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
  165. key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
  166. CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  167. if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
  168. for (int i = 0; i < n_bytes; i += col_type_size) {
  169. int32_t feature_ind = -1;
  170. if (col_type == mindrecord::ColumnInt32) {
  171. feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
  172. } else if (col_type == mindrecord::ColumnInt64) {
  173. feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
  174. } else {
  175. RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
  176. }
  177. if (feature_ind >= 0) indices->push_back(feature_ind);
  178. }
  179. return Status::OK();
  180. }
  181. Status GraphLoader::WorkerEntry(int32_t worker_id) {
  182. // Handshake
  183. TaskManager::FindMe()->Post();
  184. auto ret = shard_reader_->GetNextById(row_id_++, worker_id);
  185. ShardTuple rows = ret.second;
  186. while (rows.empty() == false) {
  187. RETURN_IF_INTERRUPTED();
  188. for (const auto &tupled_row : rows) {
  189. std::vector<uint8_t> col_blob = std::get<0>(tupled_row);
  190. mindrecord::json col_jsn = std::get<1>(tupled_row);
  191. std::string attr = col_jsn["attribute"];
  192. if (attr == "n") {
  193. std::shared_ptr<Node> node_ptr;
  194. RETURN_IF_NOT_OK(
  195. LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
  196. n_deques_[worker_id].emplace_back(node_ptr);
  197. } else if (attr == "e") {
  198. std::shared_ptr<Edge> edge_ptr;
  199. RETURN_IF_NOT_OK(
  200. LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
  201. e_deques_[worker_id].emplace_back(edge_ptr);
  202. } else {
  203. MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node.";
  204. }
  205. }
  206. auto rc = shard_reader_->GetNextById(row_id_++, worker_id);
  207. rows = rc.second;
  208. }
  209. return Status::OK();
  210. }
  211. void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map,
  212. DefaultFeatureMap *default_feature_map) {
  213. for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
  214. for (auto &m : n_feature_maps_[wkr_id]) {
  215. for (auto &n : m.second) (*n_feature_map)[m.first].insert(n);
  216. }
  217. for (auto &m : e_feature_maps_[wkr_id]) {
  218. for (auto &n : m.second) (*e_feature_map)[m.first].insert(n);
  219. }
  220. for (auto &m : default_feature_maps_[wkr_id]) {
  221. (*default_feature_map)[m.first] = m.second;
  222. }
  223. }
  224. n_feature_maps_.clear();
  225. e_feature_maps_.clear();
  226. }
  227. } // namespace gnn
  228. } // namespace dataset
  229. } // namespace mindspore