You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_loader.cc 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <cstring>
  #include <future>
  #include <tuple>
  #include <utility>
  #include "dataset/engine/gnn/graph_loader.h"
  #include "mindspore/ccsrc/mindrecord/include/shard_error.h"
  #include "dataset/engine/gnn/local_edge.h"
  #include "dataset/engine/gnn/local_node.h"
  23. using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
  24. namespace mindspore {
  25. namespace dataset {
  26. namespace gnn {
  27. using mindrecord::MSRStatus;
  28. GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
  29. : mr_path_(mr_filepath),
  30. num_workers_(num_workers),
  31. row_id_(0),
  32. keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
  33. Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
  34. EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map,
  35. EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) {
  36. for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
  37. while (dq.empty() == false) {
  38. std::shared_ptr<Node> node_ptr = dq.front();
  39. n_id_map->insert({node_ptr->id(), node_ptr});
  40. (*n_type_map)[node_ptr->type()].push_back(node_ptr->id());
  41. dq.pop_front();
  42. }
  43. }
  44. for (std::deque<std::shared_ptr<Edge>> &dq : e_deques_) {
  45. while (dq.empty() == false) {
  46. std::shared_ptr<Edge> edge_ptr = dq.front();
  47. std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p;
  48. RETURN_IF_NOT_OK(edge_ptr->GetNode(&p));
  49. auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id());
  50. CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first));
  51. CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first));
  52. RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
  53. RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
  54. e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
  55. (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id());
  56. dq.pop_front();
  57. }
  58. }
  59. for (auto &itr : *n_type_map) itr.second.shrink_to_fit();
  60. for (auto &itr : *e_type_map) itr.second.shrink_to_fit();
  61. MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map);
  62. return Status::OK();
  63. }
  64. Status GraphLoader::InitAndLoad() {
  65. CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n");
  66. CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n");
  67. n_deques_.resize(num_workers_);
  68. e_deques_.resize(num_workers_);
  69. n_feature_maps_.resize(num_workers_);
  70. e_feature_maps_.resize(num_workers_);
  71. default_feature_maps_.resize(num_workers_);
  72. std::vector<std::future<Status>> r_codes(num_workers_);
  73. shard_reader_ = std::make_unique<ShardReader>();
  74. CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
  75. "Fail to open" + mr_path_);
  76. CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
  77. CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
  78. mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"];
  79. for (const std::string &key : keys_) {
  80. if (schema.find(key) == schema.end()) {
  81. RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
  82. }
  83. }
  84. // launching worker threads
  85. for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
  86. r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id);
  87. }
  88. // wait for threads to finish and check its return code
  89. for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
  90. RETURN_IF_NOT_OK(r_codes[wkr_id].get());
  91. }
  92. return Status::OK();
  93. }
  94. Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
  95. std::shared_ptr<Node> *node, NodeFeatureMap *feature_map,
  96. DefaultFeatureMap *default_feature) {
  97. NodeIdType node_id = col_jsn["first_id"];
  98. NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
  99. (*node) = std::make_shared<LocalNode>(node_id, node_type);
  100. std::vector<int32_t> indices;
  101. RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices));
  102. for (int32_t ind : indices) {
  103. std::shared_ptr<Tensor> tensor;
  104. RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
  105. RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
  106. (*feature_map)[node_type].insert(ind);
  107. if ((*default_feature)[ind] == nullptr) {
  108. std::shared_ptr<Tensor> zero_tensor;
  109. RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
  110. RETURN_IF_NOT_OK(zero_tensor->Zero());
  111. (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
  112. }
  113. }
  114. return Status::OK();
  115. }
  116. Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
  117. std::shared_ptr<Edge> *edge, EdgeFeatureMap *feature_map,
  118. DefaultFeatureMap *default_feature) {
  119. EdgeIdType edge_id = col_jsn["first_id"];
  120. EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]);
  121. NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"];
  122. std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1);
  123. std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
  124. (*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
  125. std::vector<int32_t> indices;
  126. RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices));
  127. for (int32_t ind : indices) {
  128. std::shared_ptr<Tensor> tensor;
  129. RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
  130. RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
  131. (*feature_map)[edge_type].insert(ind);
  132. if ((*default_feature)[ind] == nullptr) {
  133. std::shared_ptr<Tensor> zero_tensor;
  134. RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
  135. RETURN_IF_NOT_OK(zero_tensor->Zero());
  136. (*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
  137. }
  138. }
  139. return Status::OK();
  140. }
  141. Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
  142. const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) {
  143. const unsigned char *data = nullptr;
  144. std::unique_ptr<unsigned char[]> data_ptr;
  145. uint64_t n_bytes = 0, col_type_size = 1;
  146. mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  147. std::vector<int64_t> column_shape;
  148. MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
  149. key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
  150. CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
  151. if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
  152. RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible,
  153. std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
  154. std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data));
  155. return Status::OK();
  156. }
  157. Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
  158. const mindrecord::json &col_jsn, std::vector<int32_t> *indices) {
  159. const unsigned char *data = nullptr;
  160. std::unique_ptr<unsigned char[]> data_ptr;
  161. uint64_t n_bytes = 0, col_type_size = 1;
  162. mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  163. std::vector<int64_t> column_shape;
  164. MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
  165. key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
  166. CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  167. if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
  168. for (int i = 0; i < n_bytes; i += col_type_size) {
  169. int32_t feature_ind = -1;
  170. if (col_type == mindrecord::ColumnInt32) {
  171. feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
  172. } else if (col_type == mindrecord::ColumnInt64) {
  173. feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
  174. } else {
  175. RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
  176. }
  177. if (feature_ind >= 0) indices->push_back(feature_ind);
  178. }
  179. return Status::OK();
  180. }
  181. Status GraphLoader::WorkerEntry(int32_t worker_id) {
  182. ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id);
  183. while (rows.empty() == false) {
  184. for (const auto &tupled_row : rows) {
  185. std::vector<uint8_t> col_blob = std::get<0>(tupled_row);
  186. mindrecord::json col_jsn = std::get<1>(tupled_row);
  187. std::string attr = col_jsn["attribute"];
  188. if (attr == "n") {
  189. std::shared_ptr<Node> node_ptr;
  190. RETURN_IF_NOT_OK(
  191. LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
  192. n_deques_[worker_id].emplace_back(node_ptr);
  193. } else if (attr == "e") {
  194. std::shared_ptr<Edge> edge_ptr;
  195. RETURN_IF_NOT_OK(
  196. LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
  197. e_deques_[worker_id].emplace_back(edge_ptr);
  198. } else {
  199. MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node.";
  200. }
  201. }
  202. rows = shard_reader_->GetNextById(row_id_++, worker_id);
  203. }
  204. return Status::OK();
  205. }
  206. void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map,
  207. DefaultFeatureMap *default_feature_map) {
  208. for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
  209. for (auto &m : n_feature_maps_[wkr_id]) {
  210. for (auto &n : m.second) (*n_feature_map)[m.first].insert(n);
  211. }
  212. for (auto &m : e_feature_maps_[wkr_id]) {
  213. for (auto &n : m.second) (*e_feature_map)[m.first].insert(n);
  214. }
  215. for (auto &m : default_feature_maps_[wkr_id]) {
  216. (*default_feature_map)[m.first] = m.second;
  217. }
  218. }
  219. n_feature_maps_.clear();
  220. e_feature_maps_.clear();
  221. }
  222. } // namespace gnn
  223. } // namespace dataset
  224. } // namespace mindspore