You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graph.cc 26 kB

5 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/gnn/graph.h"
  17. #include <algorithm>
  18. #include <functional>
  19. #include <iterator>
  20. #include <numeric>
  21. #include <utility>
  22. #include "dataset/core/tensor_shape.h"
  23. #include "dataset/util/random.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. namespace gnn {
// Constructor. Stores the dataset file path and worker count, seeds the
// internal random engine with the globally configured seed, and binds the
// random-walk helper to this graph instance.
Graph::Graph(std::string dataset_file, int32_t num_workers)
    : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) {
  rnd_.seed(GetSeed());
  MS_LOG(INFO) << "num_workers:" << num_workers;
}
  32. Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
  33. auto itr = node_type_map_.find(node_type);
  34. if (itr == node_type_map_.end()) {
  35. std::string err_msg = "Invalid node type:" + std::to_string(node_type);
  36. RETURN_STATUS_UNEXPECTED(err_msg);
  37. } else {
  38. RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out));
  39. }
  40. return Status::OK();
  41. }
// Creates a 2-D tensor of shape (data.size(), data[0].size()) from a vector of
// vectors, then squeezes away any dimension of size 1 (so a single row yields
// a 1-D tensor). All inner vectors must have the same length and T must be
// compatible with the requested DataType.
template <typename T>
Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type,
                                   std::shared_ptr<Tensor> *out) {
  if (!type.IsCompatible<T>()) {
    RETURN_STATUS_UNEXPECTED("Data type not compatible");
  }
  if (data.empty()) {
    RETURN_STATUS_UNEXPECTED("Input data is empty");
  }
  std::shared_ptr<Tensor> tensor;
  size_t m = data.size();     // number of rows
  size_t n = data[0].size();  // expected width of every row
  RETURN_IF_NOT_OK(Tensor::CreateTensor(
    &tensor, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(m), static_cast<dsize_t>(n)}), type, nullptr));
  auto ptr = tensor->begin<T>();
  // Copy row by row through the tensor's element iterator, verifying that the
  // input really is rectangular.
  for (const auto &id_m : data) {
    CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size");
    for (const auto &id_n : id_m) {
      *ptr = id_n;
      ptr++;
    }
  }
  tensor->Squeeze();
  *out = std::move(tensor);
  return Status::OK();
}
  68. template <typename T>
  69. Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) {
  70. if (!data || data->empty()) {
  71. RETURN_STATUS_UNEXPECTED("Input data is empty");
  72. }
  73. for (std::vector<T> &vec : *data) {
  74. size_t size = vec.size();
  75. if (size > max_size) {
  76. RETURN_STATUS_UNEXPECTED("The max_size parameter is abnormal");
  77. } else {
  78. for (size_t i = 0; i < (max_size - size); ++i) {
  79. vec.push_back(default_value);
  80. }
  81. }
  82. }
  83. return Status::OK();
  84. }
  85. Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
  86. auto itr = edge_type_map_.find(edge_type);
  87. if (itr == edge_type_map_.end()) {
  88. std::string err_msg = "Invalid edge type:" + std::to_string(edge_type);
  89. RETURN_STATUS_UNEXPECTED(err_msg);
  90. } else {
  91. RETURN_IF_NOT_OK(CreateTensorByVector<EdgeIdType>({itr->second}, DataType(DataType::DE_INT32), out));
  92. }
  93. return Status::OK();
  94. }
  95. Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) {
  96. if (edge_list.empty()) {
  97. RETURN_STATUS_UNEXPECTED("Input edge_list is empty");
  98. }
  99. std::vector<std::vector<NodeIdType>> node_list;
  100. node_list.reserve(edge_list.size());
  101. for (const auto &edge_id : edge_list) {
  102. auto itr = edge_id_map_.find(edge_id);
  103. if (itr == edge_id_map_.end()) {
  104. std::string err_msg = "Invalid edge id:" + std::to_string(edge_id);
  105. RETURN_STATUS_UNEXPECTED(err_msg);
  106. } else {
  107. std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> nodes;
  108. RETURN_IF_NOT_OK(itr->second->GetNode(&nodes));
  109. node_list.push_back({nodes.first->id(), nodes.second->id()});
  110. }
  111. }
  112. RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(node_list, DataType(DataType::DE_INT32), out));
  113. return Status::OK();
  114. }
  115. Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
  116. std::shared_ptr<Tensor> *out) {
  117. if (node_list.empty()) {
  118. RETURN_STATUS_UNEXPECTED("Input node_list is empty.");
  119. }
  120. if (node_type_map_.find(neighbor_type) == node_type_map_.end()) {
  121. std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type);
  122. RETURN_STATUS_UNEXPECTED(err_msg);
  123. }
  124. std::vector<std::vector<NodeIdType>> neighbors;
  125. size_t max_neighbor_num = 0;
  126. neighbors.resize(node_list.size());
  127. for (size_t i = 0; i < node_list.size(); ++i) {
  128. std::shared_ptr<Node> node;
  129. RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node));
  130. RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i]));
  131. max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size();
  132. }
  133. RETURN_IF_NOT_OK(ComplementVector<NodeIdType>(&neighbors, max_neighbor_num, kDefaultNodeId));
  134. RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors, DataType(DataType::DE_INT32), out));
  135. return Status::OK();
  136. }
  137. Status Graph::CheckSamplesNum(NodeIdType samples_num) {
  138. NodeIdType all_nodes_number =
  139. std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0,
  140. [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); });
  141. if ((samples_num < 1) || (samples_num > all_nodes_number)) {
  142. std::string err_msg = "Wrong samples number, should be between 1 and " + std::to_string(all_nodes_number) +
  143. ", got " + std::to_string(samples_num);
  144. RETURN_STATUS_UNEXPECTED(err_msg);
  145. }
  146. return Status::OK();
  147. }
  148. Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
  149. const std::vector<NodeIdType> &neighbor_nums,
  150. const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
  151. CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
  152. CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
  153. "The sizes of neighbor_nums and neighbor_types are inconsistent.");
  154. for (const auto &num : neighbor_nums) {
  155. RETURN_IF_NOT_OK(CheckSamplesNum(num));
  156. }
  157. for (const auto &type : neighbor_types) {
  158. if (node_type_map_.find(type) == node_type_map_.end()) {
  159. std::string err_msg = "Invalid neighbor type:" + std::to_string(type);
  160. RETURN_STATUS_UNEXPECTED(err_msg);
  161. }
  162. }
  163. std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size());
  164. for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
  165. std::shared_ptr<Node> input_node;
  166. RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &input_node));
  167. neighbors_vec[node_idx].emplace_back(node_list[node_idx]);
  168. std::vector<NodeIdType> input_list = {node_list[node_idx]};
  169. for (size_t i = 0; i < neighbor_nums.size(); ++i) {
  170. std::vector<NodeIdType> neighbors;
  171. neighbors.reserve(input_list.size() * neighbor_nums[i]);
  172. for (const auto &node_id : input_list) {
  173. if (node_id == kDefaultNodeId) {
  174. for (int32_t j = 0; j < neighbor_nums[i]; ++j) {
  175. neighbors.emplace_back(kDefaultNodeId);
  176. }
  177. } else {
  178. std::shared_ptr<Node> node;
  179. RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node));
  180. std::vector<NodeIdType> out;
  181. RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out));
  182. neighbors.insert(neighbors.end(), out.begin(), out.end());
  183. }
  184. }
  185. neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end());
  186. input_list = std::move(neighbors);
  187. }
  188. }
  189. RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out));
  190. return Status::OK();
  191. }
  192. Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::unordered_set<NodeIdType> &exclude_data,
  193. int32_t samples_num, std::vector<NodeIdType> *out_samples) {
  194. CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
  195. std::vector<NodeIdType> shuffled_id(data.size());
  196. std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
  197. std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
  198. for (const auto &index : shuffled_id) {
  199. if (exclude_data.find(data[index]) != exclude_data.end()) {
  200. continue;
  201. }
  202. out_samples->emplace_back(data[index]);
  203. if (out_samples->size() >= samples_num) {
  204. break;
  205. }
  206. }
  207. return Status::OK();
  208. }
  209. Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
  210. NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
  211. CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
  212. RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
  213. if (node_type_map_.find(neg_neighbor_type) == node_type_map_.end()) {
  214. std::string err_msg = "Invalid neighbor type:" + std::to_string(neg_neighbor_type);
  215. RETURN_STATUS_UNEXPECTED(err_msg);
  216. }
  217. std::vector<std::vector<NodeIdType>> neighbors_vec;
  218. neighbors_vec.resize(node_list.size());
  219. for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
  220. std::shared_ptr<Node> node;
  221. RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node));
  222. std::vector<NodeIdType> neighbors;
  223. RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors));
  224. std::unordered_set<NodeIdType> exclude_node;
  225. std::transform(neighbors.begin(), neighbors.end(),
  226. std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_node, exclude_node.begin()),
  227. [](const NodeIdType node) { return node; });
  228. auto itr = node_type_map_.find(neg_neighbor_type);
  229. if (itr == node_type_map_.end()) {
  230. std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type);
  231. RETURN_STATUS_UNEXPECTED(err_msg);
  232. } else {
  233. neighbors_vec[node_idx].emplace_back(node->id());
  234. if (itr->second.size() > exclude_node.size()) {
  235. while (neighbors_vec[node_idx].size() < samples_num + 1) {
  236. RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node, samples_num - neighbors_vec[node_idx].size(),
  237. &neighbors_vec[node_idx]));
  238. }
  239. } else {
  240. MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
  241. << " neg_neighbor_type:" << neg_neighbor_type;
  242. // If there are no negative neighbors, they are filled with kDefaultNodeId
  243. for (int32_t i = 0; i < samples_num; ++i) {
  244. neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
  245. }
  246. }
  247. }
  248. }
  249. RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out));
  250. return Status::OK();
  251. }
  252. Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
  253. float step_home_param, float step_away_param, NodeIdType default_node,
  254. std::shared_ptr<Tensor> *out) {
  255. RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node));
  256. std::vector<std::vector<NodeIdType>> walks;
  257. RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks));
  258. RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({walks}, DataType(DataType::DE_INT32), out));
  259. return Status::OK();
  260. }
  261. Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
  262. auto itr = default_feature_map_.find(feature_type);
  263. if (itr == default_feature_map_.end()) {
  264. std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
  265. RETURN_STATUS_UNEXPECTED(err_msg);
  266. } else {
  267. *out_feature = itr->second;
  268. }
  269. return Status::OK();
  270. }
// Builds one feature tensor per requested feature type for every node id in
// `nodes`. Nodes equal to kDefaultNodeId, and nodes lacking the feature, are
// filled with the registered default feature for that type.
Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
                             TensorRow *out) {
  if (!nodes || nodes->Size() == 0) {
    RETURN_STATUS_UNEXPECTED("Input nodes is empty");
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty");
  TensorRow tensors;
  for (const auto &f_type : feature_types) {
    std::shared_ptr<Feature> default_feature;
    // If no feature can be obtained, fill in the default value
    RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature));
    TensorShape shape(default_feature->Value()->shape());
    auto shape_vec = nodes->shape().AsVector();
    // Total number of node ids in `nodes`, regardless of its rank.
    dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>());
    shape = shape.PrependDim(size);
    std::shared_ptr<Tensor> fea_tensor;
    RETURN_IF_NOT_OK(
      Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr));
    dsize_t index = 0;
    // Insert each node's feature tensor at the next row of fea_tensor.
    for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) {
      std::shared_ptr<Feature> feature;
      if (*node_itr == kDefaultNodeId) {
        // Placeholder node: use the default feature directly.
        feature = default_feature;
      } else {
        std::shared_ptr<Node> node;
        RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node));
        // Fall back to the default feature if the node has no such feature.
        if (!node->GetFeatures(f_type, &feature).IsOk()) {
          feature = default_feature;
        }
      }
      RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value()));
      index++;
    }
    // Reshape so the leading dimensions mirror the shape of `nodes`, then
    // squeeze away any size-1 dimensions.
    TensorShape reshape(nodes->shape());
    for (auto s : default_feature->Value()->shape().AsVector()) {
      reshape = reshape.AppendDim(s);
    }
    RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape));
    fea_tensor->Squeeze();
    tensors.push_back(fea_tensor);
  }
  *out = std::move(tensors);
  return Status::OK();
}
// Gets the feature tensors of the given edges.
// NOTE(review): not implemented -- this currently returns OK without writing
// anything to *out. Callers receive an empty result with no error.
Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
                             TensorRow *out) {
  return Status::OK();
}
// Initializes the graph by loading all nodes, edges and features from the
// dataset file into memory.
Status Graph::Init() {
  RETURN_IF_NOT_OK(LoadNodeAndEdge());
  return Status::OK();
}
  323. Status Graph::GetMetaInfo(MetaInfo *meta_info) {
  324. meta_info->node_type.resize(node_type_map_.size());
  325. std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(),
  326. [](auto itr) { return itr.first; });
  327. std::sort(meta_info->node_type.begin(), meta_info->node_type.end());
  328. meta_info->edge_type.resize(edge_type_map_.size());
  329. std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(),
  330. [](auto itr) { return itr.first; });
  331. std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end());
  332. for (const auto &node : node_type_map_) {
  333. meta_info->node_num[node.first] = node.second.size();
  334. }
  335. for (const auto &edge : edge_type_map_) {
  336. meta_info->edge_num[edge.first] = edge.second.size();
  337. }
  338. for (const auto &node_feature : node_feature_map_) {
  339. for (auto type : node_feature.second) {
  340. meta_info->node_feature_type.emplace_back(type);
  341. }
  342. }
  343. std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end());
  344. auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end());
  345. meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end());
  346. for (const auto &edge_feature : edge_feature_map_) {
  347. for (const auto &type : edge_feature.second) {
  348. meta_info->edge_feature_type.emplace_back(type);
  349. }
  350. }
  351. std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end());
  352. auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end());
  353. meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end());
  354. return Status::OK();
  355. }
// Exports the graph's meta information (types, counts, feature types) as a
// Python dict for the Python-side API.
Status Graph::GraphInfo(py::dict *out) {
  MetaInfo meta_info;
  RETURN_IF_NOT_OK(GetMetaInfo(&meta_info));
  (*out)["node_type"] = py::cast(meta_info.node_type);
  (*out)["edge_type"] = py::cast(meta_info.edge_type);
  (*out)["node_num"] = py::cast(meta_info.node_num);
  (*out)["edge_num"] = py::cast(meta_info.edge_num);
  (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type);
  (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type);
  return Status::OK();
}
// Loads the entire dataset into memory via GraphLoader and takes ownership of
// the resulting id, type and feature maps.
Status Graph::LoadNodeAndEdge() {
  GraphLoader gl(dataset_file_, num_workers_);
  // ask graph_loader to load everything into memory
  RETURN_IF_NOT_OK(gl.InitAndLoad());
  // get all maps
  RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_,
                                       &node_feature_map_, &edge_feature_map_, &default_feature_map_));
  return Status::OK();
}
  376. Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
  377. auto itr = node_id_map_.find(id);
  378. if (itr == node_id_map_.end()) {
  379. std::string err_msg = "Invalid node id:" + std::to_string(id);
  380. RETURN_STATUS_UNEXPECTED(err_msg);
  381. } else {
  382. *node = itr->second;
  383. }
  384. return Status::OK();
  385. }
// Constructor. Binds the walker to its owning graph; walk parameters receive
// neutral defaults (step params of 1.0, one walk, one worker) until Build().
Graph::RandomWalkBase::RandomWalkBase(Graph *graph)
    : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {}
  388. Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
  389. float step_home_param, float step_away_param, const NodeIdType default_node,
  390. int32_t num_walks, int32_t num_workers) {
  391. node_list_ = node_list;
  392. if (meta_path.empty() || meta_path.size() > kMaxNumWalks) {
  393. std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) +
  394. ". The size of input path is " + std::to_string(meta_path.size());
  395. RETURN_STATUS_UNEXPECTED(err_msg);
  396. }
  397. meta_path_ = meta_path;
  398. if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) {
  399. std::string err_msg = "Failed, step_home_param and step_away_param required greater than " +
  400. std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) +
  401. ", step_away_param: " + std::to_string(step_away_param);
  402. RETURN_STATUS_UNEXPECTED(err_msg);
  403. }
  404. step_home_param_ = step_home_param;
  405. step_away_param_ = step_away_param;
  406. default_node_ = default_node;
  407. num_walks_ = num_walks;
  408. num_workers_ = num_workers;
  409. return Status::OK();
  410. }
// Simulates one node2vec walk of length meta_path_.size() + 1 starting at
// start_node. The walk stops early when the current node has no neighbors of
// the next meta-path type; remaining slots are padded with default_node_.
Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) {
  // Simulate a random walk starting from start node.
  auto walk = std::vector<NodeIdType>(1, start_node);  // walk is a vector
  // walk simulate
  while (walk.size() - 1 < meta_path_.size()) {
    // current node
    auto cur_node_id = walk.back();
    std::shared_ptr<Node> cur_node;
    RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node));
    // current neighbors (sorted for deterministic indexing below)
    std::vector<NodeIdType> cur_neighbors;
    RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true));
    std::sort(cur_neighbors.begin(), cur_neighbors.end());
    // break if no neighbors
    if (cur_neighbors.empty()) {
      break;
    }
    // walk by the first node, then by the previous 2 nodes
    std::shared_ptr<StochasticIndex> stochastic_index;
    if (walk.size() == 1) {
      // First step: uniform probability over the start node's neighbors.
      RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index));
    } else {
      // Later steps: probabilities biased by the previous edge (p/q params).
      NodeIdType prev_node_id = walk[walk.size() - 2];
      RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index));
    }
    NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)];
    walk.push_back(next_node_id);
  }
  // Pad early-terminated walks so every returned walk has the same length.
  while (walk.size() - 1 < meta_path_.size()) {
    walk.push_back(default_node_);
  }
  *walk_path = std::move(walk);
  return Status::OK();
}
// Repeatedly simulates random walks: num_walks_ rounds, each visiting every
// node of node_list_ in a freshly shuffled order; one walk per node per round
// is appended to *walks.
Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) {
  // Repeatedly simulate random walks from each node
  std::vector<uint32_t> permutation(node_list_.size());
  std::iota(permutation.begin(), permutation.end(), 0);
  for (int32_t i = 0; i < num_walks_; i++) {
    // NOTE(review): this shuffle is seeded from the wall clock, so the visit
    // order ignores the globally configured seed used elsewhere (rnd_ /
    // GetSeed) -- confirm whether reproducible walks are required here.
    unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
    std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed));
    for (const auto &i_perm : permutation) {
      std::vector<NodeIdType> walk;
      RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk));
      walks->push_back(walk);
    }
  }
  return Status::OK();
}
  460. Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
  461. std::shared_ptr<StochasticIndex> *node_probability) {
  462. // Generate alias nodes
  463. std::shared_ptr<Node> node;
  464. graph_->GetNodeByNodeId(node_id, &node);
  465. std::vector<NodeIdType> neighbors;
  466. RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true));
  467. std::sort(neighbors.begin(), neighbors.end());
  468. auto non_normalized_probability = std::vector<float>(neighbors.size(), 1.0);
  469. *node_probability =
  470. std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability)));
  471. return Status::OK();
  472. }
  473. Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index,
  474. std::shared_ptr<StochasticIndex> *edge_probability) {
  475. // Get the alias edge setup lists for a given edge.
  476. std::shared_ptr<Node> src_node;
  477. graph_->GetNodeByNodeId(src, &src_node);
  478. std::vector<NodeIdType> src_neighbors;
  479. RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true));
  480. std::shared_ptr<Node> dst_node;
  481. graph_->GetNodeByNodeId(dst, &dst_node);
  482. std::vector<NodeIdType> dst_neighbors;
  483. RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true));
  484. std::sort(dst_neighbors.begin(), dst_neighbors.end());
  485. std::vector<float> non_normalized_probability;
  486. for (const auto &dst_nbr : dst_neighbors) {
  487. if (dst_nbr == src) {
  488. non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight']
  489. continue;
  490. }
  491. auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr);
  492. if (it != src_neighbors.end()) {
  493. // stay close, this node connect both src and dst
  494. non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight']
  495. } else {
  496. // step far away
  497. non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight']
  498. }
  499. }
  500. *edge_probability =
  501. std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability)));
  502. return Status::OK();
  503. }
// Builds an alias-method style sampling table from a normalized probability
// vector: weight[i] is the acceptance threshold for bucket i, and
// switch_to_large_index[i] the alias bucket used on rejection. A small random
// jitter (+/- kGnnEpsilon) is mixed into each threshold; the last entry
// absorbs the accumulated jitter so the total mass is preserved.
StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) {
  uint32_t K = probability.size();
  std::vector<int32_t> switch_to_large_index(K, 0);
  std::vector<float> weight(K, .0);
  std::vector<int32_t> smaller;
  std::vector<int32_t> larger;
  auto random_device = GetRandomDevice();
  std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon);
  float accumulate_threshold = 0.0;
  // Scale every probability by K and split buckets into under-full (< 1.0)
  // and over-full (>= 1.0) piles.
  for (uint32_t i = 0; i < K; i++) {
    float threshold_one = distribution(random_device);
    accumulate_threshold += threshold_one;
    weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold;
    weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i);
  }
  // Pair each under-full bucket with an over-full one, moving excess mass
  // until every bucket's weight is (approximately) 1.0.
  while ((!smaller.empty()) && (!larger.empty())) {
    uint32_t small = smaller.back();
    smaller.pop_back();
    uint32_t large = larger.back();
    larger.pop_back();
    switch_to_large_index[small] = large;
    weight[large] = weight[large] + weight[small] - 1.0;
    weight[large] < 1.0 ? smaller.push_back(large) : larger.push_back(large);
  }
  return StochasticIndex(switch_to_large_index, weight);
}
// Draws the index of the next node from an alias-style table built by
// GenerateProbability: pick a random bucket, accept it with probability
// weight[bucket], otherwise jump to its alias bucket.
uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) {
  auto switch_to_large_index = stochastic_index.first;
  auto weight = stochastic_index.second;
  const uint32_t size_of_index = switch_to_large_index.size();
  auto random_device = GetRandomDevice();
  std::uniform_real_distribution<> distribution(0.0, 1.0);
  // Generate random integer between [0, K)
  uint32_t random_idx = std::floor(distribution(random_device) * size_of_index);
  if (distribution(random_device) < weight[random_idx]) {
    return random_idx;
  }
  return switch_to_large_index[random_idx];
}
  543. template <typename T>
  544. std::vector<float> Graph::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) {
  545. float sum_probability =
  546. 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0);
  547. if (sum_probability < kGnnEpsilon) {
  548. sum_probability = 1.0;
  549. }
  550. std::vector<float> normalized_probability;
  551. std::transform(non_normalized_probability.begin(), non_normalized_probability.end(),
  552. std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; });
  553. return normalized_probability;
  554. }
  555. } // namespace gnn
  556. } // namespace dataset
  557. } // namespace mindspore