|
|
|
@@ -22,6 +22,7 @@ |
|
|
|
#include "mindspore/ccsrc/mindrecord/include/shard_error.h" |
|
|
|
#include "dataset/engine/gnn/local_edge.h" |
|
|
|
#include "dataset/engine/gnn/local_node.h" |
|
|
|
#include "dataset/util/task_manager.h" |
|
|
|
|
|
|
|
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>; |
|
|
|
|
|
|
|
@@ -80,7 +81,7 @@ Status GraphLoader::InitAndLoad() { |
|
|
|
n_feature_maps_.resize(num_workers_); |
|
|
|
e_feature_maps_.resize(num_workers_); |
|
|
|
default_feature_maps_.resize(num_workers_); |
|
|
|
std::vector<std::future<Status>> r_codes(num_workers_); |
|
|
|
TaskGroup vg; |
|
|
|
|
|
|
|
shard_reader_ = std::make_unique<ShardReader>(); |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, |
|
|
|
@@ -97,12 +98,11 @@ Status GraphLoader::InitAndLoad() { |
|
|
|
|
|
|
|
// launching worker threads |
|
|
|
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { |
|
|
|
r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id); |
|
|
|
RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); |
|
|
|
} |
|
|
|
// wait for threads to finish and check its return code |
|
|
|
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { |
|
|
|
RETURN_IF_NOT_OK(r_codes[wkr_id].get()); |
|
|
|
} |
|
|
|
vg.join_all(Task::WaitFlag::kBlocking); |
|
|
|
RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny()); |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -201,8 +201,11 @@ Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<u |
|
|
|
} |
|
|
|
|
|
|
|
Status GraphLoader::WorkerEntry(int32_t worker_id) { |
|
|
|
// Handshake |
|
|
|
TaskManager::FindMe()->Post(); |
|
|
|
ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id); |
|
|
|
while (rows.empty() == false) { |
|
|
|
RETURN_IF_INTERRUPTED(); |
|
|
|
for (const auto &tupled_row : rows) { |
|
|
|
std::vector<uint8_t> col_blob = std::get<0>(tupled_row); |
|
|
|
mindrecord::json col_jsn = std::get<1>(tupled_row); |
|
|
|
|