You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gnn_graph_test.cc 16 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <algorithm>
  #include <cstdint>
  #include <iterator>
  #include <map>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <unordered_set>
  #include <utility>
  #include <vector>
  #include "common/common.h"
  #include "gtest/gtest.h"
  #include "minddata/dataset/util/status.h"
  #include "minddata/dataset/engine/gnn/node.h"
  #include "minddata/dataset/engine/gnn/graph_data_impl.h"
  #include "minddata/dataset/engine/gnn/graph_loader.h"
  using namespace mindspore::dataset;
  using namespace mindspore::dataset::gnn;
  // Logs every element of an integer container on a single INFO line, prefixed by a label.
  //   vec:   any expression yielding a container with begin()/end() over int-convertible
  //          values. It is evaluated exactly once, so passing a temporary is safe.
  //   label: anything streamable into MS_LOG, e.g. a string literal.
  // The macro's locals carry a macro-specific name so they can never shadow a caller
  // variable that is passed in as `vec`.
  #define print_int_vec(vec, label)                                                  \
    do {                                                                             \
      const auto &print_int_vec_values_ = (vec);                                     \
      std::stringstream print_int_vec_stream_;                                       \
      std::copy(print_int_vec_values_.begin(), print_int_vec_values_.end(),         \
                std::ostream_iterator<int>(print_int_vec_stream_, " "));             \
      MS_LOG(INFO) << (label) << " " << print_int_vec_stream_.str();                 \
    } while (false)
  35. class MindDataTestGNNGraph : public UT::Common {
  36. protected:
  37. MindDataTestGNNGraph() = default;
  38. using NumNeighborsMap = std::map<NodeIdType, uint32_t>;
  39. using NodeNeighborsMap = std::map<NodeIdType, NumNeighborsMap>;
  40. void ParsingNeighbors(const std::shared_ptr<Tensor> &neighbors, NodeNeighborsMap &node_neighbors) {
  41. auto shape_vec = neighbors->shape().AsVector();
  42. uint32_t num_members = 1;
  43. for (size_t i = 1; i < shape_vec.size(); ++i) {
  44. num_members *= shape_vec[i];
  45. }
  46. uint32_t index = 0;
  47. NodeIdType src_node = 0;
  48. for (auto node_itr = neighbors->begin<NodeIdType>(); node_itr != neighbors->end<NodeIdType>();
  49. ++node_itr, ++index) {
  50. if (index % num_members == 0) {
  51. src_node = *node_itr;
  52. continue;
  53. }
  54. auto src_node_itr = node_neighbors.find(src_node);
  55. if (src_node_itr == node_neighbors.end()) {
  56. node_neighbors[src_node] = {{*node_itr, 1}};
  57. } else {
  58. auto nei_itr = src_node_itr->second.find(*node_itr);
  59. if (nei_itr == src_node_itr->second.end()) {
  60. src_node_itr->second[*node_itr] = 1;
  61. } else {
  62. src_node_itr->second[*node_itr] += 1;
  63. }
  64. }
  65. }
  66. }
  67. void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector<WeightType> &weights,
  68. float deviation_ratio = 0.2) {
  69. EXPECT_EQ(number_neighbors.size(), weights.size());
  70. int index = 0;
  71. uint32_t pre_num = 0;
  72. WeightType pre_weight = 1;
  73. for (auto neighbor : number_neighbors) {
  74. if (pre_num != 0) {
  75. float target_ratio = static_cast<float>(pre_weight) / static_cast<float>(weights[index]);
  76. float current_ratio = static_cast<float>(pre_num) / static_cast<float>(neighbor.second);
  77. float target_upper = target_ratio * (1 + deviation_ratio);
  78. float target_lower = target_ratio * (1 - deviation_ratio);
  79. MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio)
  80. << " target_upper:" << std::to_string(target_upper)
  81. << " target_lower:" << std::to_string(target_lower);
  82. EXPECT_LE(current_ratio, target_upper);
  83. EXPECT_GE(current_ratio, target_lower);
  84. }
  85. pre_num = neighbor.second;
  86. pre_weight = weights[index];
  87. ++index;
  88. }
  89. }
  90. };
  91. TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) {
  92. std::string path = "data/mindrecord/testGraphData/testdata";
  93. GraphDataImpl graph(path, 1);
  94. Status s = graph.Init();
  95. EXPECT_TRUE(s.IsOk());
  96. std::vector<std::pair<NodeIdType, NodeIdType>> src_dst_list = {{101, 201}, {103, 207}, {108, 208},
  97. {110, 201}, {204, 105}, {208, 108}};
  98. std::shared_ptr<Tensor> edges;
  99. s = graph.GetEdgesFromNodes(src_dst_list, &edges);
  100. EXPECT_TRUE(s.IsOk());
  101. EXPECT_TRUE(edges->ToString() == "Tensor (shape: <6>, Type: int32)\n[1,9,17,19,31,37]");
  102. }
  103. TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
  104. std::string path = "data/mindrecord/testGraphData/testdata";
  105. GraphDataImpl graph(path, 1);
  106. Status s = graph.Init();
  107. EXPECT_TRUE(s.IsOk());
  108. MetaInfo meta_info;
  109. s = graph.GetMetaInfo(&meta_info);
  110. EXPECT_TRUE(s.IsOk());
  111. EXPECT_TRUE(meta_info.node_type.size() == 2);
  112. std::shared_ptr<Tensor> nodes;
  113. s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  114. EXPECT_TRUE(s.IsOk());
  115. std::vector<NodeIdType> node_list;
  116. for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
  117. node_list.push_back(*itr);
  118. if (node_list.size() >= 10) {
  119. break;
  120. }
  121. }
  122. std::shared_ptr<Tensor> neighbors;
  123. s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kNormal, &neighbors);
  124. EXPECT_TRUE(s.IsOk());
  125. EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>");
  126. TensorRow features;
  127. s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features);
  128. EXPECT_TRUE(s.IsOk());
  129. EXPECT_TRUE(features.size() == 4);
  130. EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>");
  131. EXPECT_TRUE(features[0]->ToString() ==
  132. "Tensor (shape: <10,5>, Type: int32)\n"
  133. "[[0,1,0,0,0],[1,0,0,0,1],[0,0,1,1,0],[0,0,0,0,0],[1,1,0,1,0],[0,0,0,0,1],[0,1,0,0,0],[0,0,0,1,1],[0,1,1,"
  134. "0,0],[0,1,0,1,0]]");
  135. EXPECT_TRUE(features[1]->shape().ToString() == "<10>");
  136. EXPECT_TRUE(features[1]->ToString() ==
  137. "Tensor (shape: <10>, Type: float32)\n[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]");
  138. EXPECT_TRUE(features[2]->shape().ToString() == "<10>");
  139. EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]");
  140. }
  141. TEST_F(MindDataTestGNNGraph, TestGetAllNeighborsSpecialFormat) {
  142. std::string path = "data/mindrecord/testGraphData/testdata";
  143. GraphDataImpl graph(path, 1);
  144. Status s = graph.Init();
  145. EXPECT_TRUE(s.IsOk());
  146. MetaInfo meta_info;
  147. s = graph.GetMetaInfo(&meta_info);
  148. EXPECT_TRUE(s.IsOk());
  149. EXPECT_TRUE(meta_info.node_type.size() == 2);
  150. std::shared_ptr<Tensor> nodes;
  151. s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  152. EXPECT_TRUE(s.IsOk());
  153. std::vector<NodeIdType> node_list;
  154. for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
  155. node_list.push_back(*itr);
  156. if (node_list.size() >= 10) {
  157. break;
  158. }
  159. }
  160. // Check COO format
  161. std::shared_ptr<Tensor> neighbors_coo;
  162. s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCoo, &neighbors_coo);
  163. EXPECT_TRUE(s.IsOk());
  164. EXPECT_TRUE(neighbors_coo->shape().ToString() == "<20,2>");
  165. EXPECT_TRUE(neighbors_coo->ToString() ==
  166. "Tensor (shape: <20,2>, Type: int32)\n"
  167. "[[101,201],[101,205],[101,206],[102,201],[102,202],[103,203],[103,205],[103,206],[103,207],[103,208],"
  168. "[105,204],[106,202],[106,203],[107,201],[107,203],[107,207],[108,208],[109,210],[110,201],[110,210]]");
  169. // Check CSR format
  170. std::shared_ptr<Tensor> neighbors_csr;
  171. s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCsr, &neighbors_csr);
  172. EXPECT_TRUE(s.IsOk());
  173. EXPECT_TRUE(neighbors_csr->shape().ToString() == "<30>");
  174. EXPECT_TRUE(
  175. neighbors_csr->ToString() ==
  176. "Tensor (shape: <30>, Type: int32)\n"
  177. "[0,3,5,10,10,11,13,16,17,18,201,205,206,201,202,203,205,206,207,208,204,202,203,201,203,207,208,210,201,210]");
  178. }
  179. TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
  180. std::string path = "data/mindrecord/testGraphData/testdata";
  181. GraphDataImpl graph(path, 1);
  182. Status s = graph.Init();
  183. EXPECT_TRUE(s.IsOk());
  184. MetaInfo meta_info;
  185. s = graph.GetMetaInfo(&meta_info);
  186. EXPECT_TRUE(s.IsOk());
  187. EXPECT_TRUE(meta_info.node_type.size() == 2);
  188. std::shared_ptr<Tensor> edges;
  189. s = graph.GetAllEdges(meta_info.edge_type[0], &edges);
  190. EXPECT_TRUE(s.IsOk());
  191. std::vector<EdgeIdType> edge_list;
  192. edge_list.resize(edges->Size());
  193. std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(),
  194. [](const EdgeIdType edge) { return edge; });
  195. TensorRow edge_features;
  196. s = graph.GetEdgeFeature(edges, meta_info.edge_feature_type, &edge_features);
  197. EXPECT_TRUE(s.IsOk());
  198. EXPECT_TRUE(edge_features[0]->ToString() ==
  199. "Tensor (shape: <40>, Type: int32)\n"
  200. "[0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0]");
  201. EXPECT_TRUE(edge_features[1]->ToString() ==
  202. "Tensor (shape: <40>, Type: float32)\n"
  203. "[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2,2.1,2.2,2.3,2.4,2.5,2.6,2."
  204. "7,2.8,2.9,3,3.1,3.2,3.3,3.4,3.5,3.6,3.7,3.8,3.9,4]");
  205. std::shared_ptr<Tensor> nodes;
  206. s = graph.GetNodesFromEdges(edge_list, &nodes);
  207. EXPECT_TRUE(s.IsOk());
  208. std::unordered_set<NodeIdType> node_set;
  209. std::vector<NodeIdType> node_list;
  210. int index = 0;
  211. for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
  212. index++;
  213. if (index % 2 == 0) {
  214. continue;
  215. }
  216. node_set.emplace(*itr);
  217. if (node_set.size() >= 5) {
  218. break;
  219. }
  220. }
  221. node_list.resize(node_set.size());
  222. std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; });
  223. std::shared_ptr<Tensor> neighbors;
  224. {
  225. MS_LOG(INFO) << "Test random sampling.";
  226. NodeNeighborsMap number_neighbors;
  227. int count = 0;
  228. while (count < 1000) {
  229. neighbors.reset();
  230. s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
  231. EXPECT_TRUE(s.IsOk());
  232. EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
  233. ParsingNeighbors(neighbors, number_neighbors);
  234. ++count;
  235. }
  236. CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1});
  237. }
  238. {
  239. MS_LOG(INFO) << "Test edge weight sampling.";
  240. NodeNeighborsMap number_neighbors;
  241. int count = 0;
  242. while (count < 1000) {
  243. neighbors.reset();
  244. s =
  245. graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors);
  246. EXPECT_TRUE(s.IsOk());
  247. EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
  248. ParsingNeighbors(neighbors, number_neighbors);
  249. ++count;
  250. }
  251. CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8});
  252. }
  253. neighbors.reset();
  254. s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]},
  255. SamplingStrategy::kRandom, &neighbors);
  256. EXPECT_TRUE(s.IsOk());
  257. EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>");
  258. neighbors.reset();
  259. s = graph.GetSampledNeighbors(node_list, {2, 3, 4},
  260. {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]},
  261. SamplingStrategy::kRandom, &neighbors);
  262. EXPECT_TRUE(s.IsOk());
  263. EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>");
  264. neighbors.reset();
  265. s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
  266. EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
  267. neighbors.reset();
  268. s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
  269. EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
  270. neighbors.reset();
  271. s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]},
  272. SamplingStrategy::kRandom, &neighbors);
  273. EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
  274. neighbors.reset();
  275. s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors);
  276. EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
  277. neighbors.reset();
  278. s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]},
  279. SamplingStrategy::kRandom, &neighbors);
  280. EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
  281. std::string::npos);
  282. neighbors.reset();
  283. s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
  284. EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos);
  285. }
  286. TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
  287. std::string path = "data/mindrecord/testGraphData/testdata";
  288. GraphDataImpl graph(path, 1);
  289. Status s = graph.Init();
  290. EXPECT_TRUE(s.IsOk());
  291. MetaInfo meta_info;
  292. s = graph.GetMetaInfo(&meta_info);
  293. EXPECT_TRUE(s.IsOk());
  294. EXPECT_TRUE(meta_info.node_type.size() == 2);
  295. std::shared_ptr<Tensor> nodes;
  296. s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  297. EXPECT_TRUE(s.IsOk());
  298. std::vector<NodeIdType> node_list;
  299. for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
  300. node_list.push_back(*itr);
  301. if (node_list.size() >= 10) {
  302. break;
  303. }
  304. }
  305. std::shared_ptr<Tensor> neg_neighbors;
  306. s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors);
  307. EXPECT_TRUE(s.IsOk());
  308. EXPECT_TRUE(neg_neighbors->shape().ToString() == "<10,4>");
  309. neg_neighbors.reset();
  310. s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors);
  311. EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
  312. neg_neighbors.reset();
  313. s = graph.GetNegSampledNeighbors({-1, 1}, 3, meta_info.node_type[1], &neg_neighbors);
  314. EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
  315. neg_neighbors.reset();
  316. s = graph.GetNegSampledNeighbors(node_list, 50, meta_info.node_type[1], &neg_neighbors);
  317. EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
  318. neg_neighbors.reset();
  319. s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
  320. EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
  321. }
  322. TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
  323. std::string path = "data/mindrecord/testGraphData/sns";
  324. GraphDataImpl graph(path, 1);
  325. Status s = graph.Init();
  326. EXPECT_TRUE(s.IsOk());
  327. MetaInfo meta_info;
  328. s = graph.GetMetaInfo(&meta_info);
  329. EXPECT_TRUE(s.IsOk());
  330. std::shared_ptr<Tensor> nodes;
  331. s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  332. EXPECT_TRUE(s.IsOk());
  333. std::vector<NodeIdType> node_list;
  334. for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
  335. node_list.push_back(*itr);
  336. }
  337. print_int_vec(node_list, "node list ");
  338. std::vector<NodeType> meta_path(59, 1);
  339. std::shared_ptr<Tensor> walk_path;
  340. s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path);
  341. EXPECT_TRUE(s.IsOk());
  342. EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>");
  343. }
  344. TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
  345. std::string path = "data/mindrecord/testGraphData/sns";
  346. GraphDataImpl graph(path, 1);
  347. Status s = graph.Init();
  348. EXPECT_TRUE(s.IsOk());
  349. MetaInfo meta_info;
  350. s = graph.GetMetaInfo(&meta_info);
  351. EXPECT_TRUE(s.IsOk());
  352. std::shared_ptr<Tensor> nodes;
  353. s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
  354. EXPECT_TRUE(s.IsOk());
  355. std::vector<NodeIdType> node_list;
  356. for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
  357. node_list.push_back(*itr);
  358. }
  359. print_int_vec(node_list, "node list ");
  360. std::vector<NodeType> meta_path(59, 1);
  361. std::shared_ptr<Tensor> walk_path;
  362. s = graph.RandomWalk(node_list, meta_path, 1.0, 1.0, -1, &walk_path);
  363. EXPECT_TRUE(s.IsOk());
  364. EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>");
  365. }