You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.GraphData.rst 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. Class mindspore.dataset.GraphData(dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, num_client=1, auto_shutdown=True)
  2. 从共享文件和数据库中读取用于GNN训练的图数据集。
  3. **参数:**
  4. - **dataset_file** (str):数据集文件路径。
  5. - **num_parallel_workers** (int, 可选):读取数据的工作线程数(默认为None)。
  6. - **working_mode** (str, 可选):设置工作模式,目前支持'local'/'client'/'server'(默认为'local')。
  7. -'local',用于非分布式训练场景。
  8. -'client',用于分布式训练场景。客户端不加载数据,而是从服务器获取数据。
  9. -'server',用于分布式训练场景。服务器加载数据并可供客户端使用。
  10. - **hostname** (str, 可选):图数据集服务器的主机名。该参数仅在工作模式设置为'client'或'server'时有效(默认为'127.0.0.1')。
  11. - **port** (int, 可选):图数据服务器的端口,取值范围为1024-65535。此参数仅当工作模式设置为'client'或'server'(默认为50051)时有效。
  12. - **num_client** (int, 可选):期望连接到服务器的最大客户端数。服务器将根据该参数分配资源。该参数仅在工作模式设置为'server'时有效(默认为1)。
  13. - **auto_shutdown** (bool, 可选):当工作模式设置为'server'时有效。当连接的客户端数量达到num_client,且没有客户端正在连接时,服务器将自动退出(默认为True)。
  14. **示例:**
  15. >>> graph_dataset_dir = "/path/to/graph_dataset_file"
  16. >>> graph_dataset = ds.GraphData(dataset_file=graph_dataset_dir, num_parallel_workers=2)
  17. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  18. >>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[1])
  19. get_all_edges(edge_type)
  20. 获取图的所有边。
  21. **参数:**
  22. edge_type (int):指定边的类型。
  23. **返回:**
  24. numpy.ndarray,包含边的数组。
  25. **示例:**
  26. >>> edges = graph_dataset.get_all_edges(edge_type=0)
  27. **异常:**
  28. - **TypeError**:参数`edge_type`的类型不为整型。
  29. get_all_neighbors(node_list, neighbor_type, output_format=<OutputFormat.NORMAL: 0。
  30. 获取`node_list`所有节点的邻居,以`neighbor_type`类型返回。
  31. 格式的定义参见以下示例:1表示两个节点之间连接,0表示不连接。
  32. .. list-table:: 邻接矩阵
  33. :widths: 20 20 20 20 20
  34. :header-rows: 1
  35. * -
  36. - 0
  37. - 1
  38. - 2
  39. - 3
  40. * - 0
  41. - 0
  42. - 1
  43. - 0
  44. - 0
  45. * - 1
  46. - 0
  47. - 0
  48. - 1
  49. - 0
  50. * - 2
  51. - 1
  52. - 0
  53. - 0
  54. - 1
  55. * - 3
  56. - 1
  57. - 0
  58. - 0
  59. - 0
  60. .. list-table:: 普通格式
  61. :widths: 20 20 20 20 20
  62. :header-rows: 1
  63. * - src
  64. - 0
  65. - 1
  66. - 2
  67. - 3
  68. * - dst_0
  69. - 1
  70. - 2
  71. - 0
  72. - 1
  73. * - dst_1
  74. - -1
  75. - -1
  76. - 3
  77. - -1
  78. .. list-table:: COO格式
  79. :widths: 20 20 20 20 20 20
  80. :header-rows: 1
  81. * - src
  82. - 0
  83. - 1
  84. - 2
  85. - 2
  86. - 3
  87. * - dst
  88. - 1
  89. - 2
  90. - 0
  91. - 3
  92. - 1
  93. .. list-table:: CSR格式
  94. :widths: 40 20 20 20 20 20
  95. :header-rows: 1
  96. * - offsetTable
  97. - 0
  98. - 1
  99. - 2
  100. - 4
  101. -
  102. * - dstTable
  103. - 1
  104. - 2
  105. - 0
  106. - 3
  107. - 1
  108. **参数:**
  109. - **node_list** (Union[list, numpy.ndarray]):给定的节点列表。
  110. - **neighbor_type** (int):指定邻居节点的类型。
  111. - **output_format** (OutputFormat, 可选):输出存储格式(默认为mindspore.dataset.engine.OutputFormat.NORMAL)取值范围:[OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR]。
  112. **返回:**
  113. 对于普通格式或COO格式,
  114. 将返回numpy.ndarray类型的数组表示邻居节点。
  115. 如果指定了CSR格式,将返回两个numpy.ndarray数组,第一个表示偏移表,第二个表示邻居节点。
  116. **示例:**
  117. >>> from mindspore.dataset.engine import OutputFormat
  118. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  119. >>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2)
  120. >>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
  121. ... output_format=OutputFormat.COO)
  122. >>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
  123. ... output_format=OutputFormat.CSR)
  124. **异常:**
  125. - **TypeError**:参数`node_list`的类型不为列表或numpy.ndarray。
  126. - **TypeError**:参数`neighbor_type`的类型不为整型。
  127. get_all_nodes(node_type)
  128. 获取图中的所有节点。
  129. **参数:**
  130. - **node_type** (int):指定节点的类型。
  131. **返回:**
  132. numpy.ndarray,包含节点的数组。
  133. **示例:**
  134. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  135. **异常:**
  136. - **TypeError**:参数`node_type`的类型不为整型。
  137. get_edges_from_nodes(node_list)
  138. 从节点获取边。
  139. **参数:**
  140. - **node_list** (Union[list[tuple], numpy.ndarray]):含一个或多个图节点ID对的列表。
  141. **返回:**
  142. numpy.ndarray,含一个或多个边ID的数组。
  143. **示例:**
  144. >>> edges = graph_dataset.get_edges_from_nodes(node_list=[(101, 201), (103, 207)])
  145. **异常:**
  146. - **TypeError**:参数`edge_list`的类型不为列表或numpy.ndarray。
  147. get_edge_feature(edge_list, feature_types)
  148. 获取`edge_list`列表中边的特征,以`feature_types`类型返回。
  149. **参数:**
  150. - **edge_list** (Union[list, numpy.ndarray]):包含边的列表。
  151. - **feature_types** (Union[list, numpy.ndarray]):包含给定特征类型的列表。
  152. **返回:**
  153. numpy.ndarray,包含特征的数组。
  154. **示例:**
  155. >>> edges = graph_dataset.get_all_edges(edge_type=0)
  156. >>> features = graph_dataset.get_edge_feature(edge_list=edges, feature_types=[1])
  157. **异常:**
  158. - **TypeError**:参数`edge_list`的类型不为列表或numpy.ndarray。
  159. - **TypeError**:参数`feature_types`的类型不为列表或numpy.ndarray。
  160. get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type)
  161. 获取`node_list`列表中节所有点的负样本邻居,以`neg_neighbor_type`类型返回。
  162. **参数:**
  163. - **node_list** (Union[list, numpy.ndarray]):包含节点的列表。
  164. - **neg_neighbor_num** (int):采样的邻居数量。
  165. - **neg_neighbor_type** (int):指定负样本邻居的类型。
  166. **返回:**
  167. numpy.ndarray,包含邻居的数组。
  168. **示例:**
  169. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  170. >>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
  171. ... neg_neighbor_type=2)
  172. **异常:**
  173. - **TypeError**:参数`node_list`的类型不为列表或numpy.ndarray。
  174. - **TypeError**:参数`neg_neighbor_num`的类型不为整型。
  175. - **TypeError**:参数`neg_neighbor_type`的类型不为整型。
  176. get_nodes_from_edges(edge_list)
  177. 从图中的边获取节点。
  178. **参数:**
  179. - **edge_list** (Union[list, numpy.ndarray]):包含边的列表。
  180. **返回:**
  181. numpy.ndarray,包含节点的数组。
  182. **异常:**
  183. TypeError:参数`edge_list`不为列表或ndarray。
  184. get_node_feature(node_list, feature_types)
  185. 获取`node_list`中节点的特征,以`feature_types`类型返回。
  186. **参数:**
  187. - **node_list** (Union[list, numpy.ndarray]):包含节点的列表。
  188. - **feature_types** (Union[list, numpy.ndarray]):指定特征的类型。
  189. **返回:**
  190. numpy.ndarray,包含特征的数组。
  191. **示例:**
  192. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  193. >>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[2, 3])
  194. **异常:**
  195. - **TypeError**:参数`node_list`的类型不为列表或numpy.ndarray。
  196. - **TypeError**:参数`feature_types`的类型不为列表或numpy.ndarray。
  197. get_sampled_neighbors(node_list, neighbor_nums, neighbor_types, strategy=<SamplingStrategy.RANDOM: 0>)
  198. 获取已采样邻居信息。此API支持多跳邻居采样。即将上一次采样结果作为下一跳采样的输入,最多允许6跳。
  199. 采样结果平铺成列表,格式为[input node, 1-hop sampling result, 2-hop samling result ...]
  200. **参数:**
  201. - **node_list** (Union[list, numpy.ndarray]):包含节点的列表。
  202. - **neighbor_nums** (Union[list, numpy.ndarray]):每跳采样的邻居数。
  203. - **neighbor_types** (Union[list, numpy.ndarray]):每跳采样的邻居类型。
  204. - **strategy** (SamplingStrategy, 可选):采样策略(默认为mindspore.dataset.engine.SamplingStrategy.RANDOM)。取值范围:[SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT]。
  205. - SamplingStrategy.RANDOM,随机抽样,带放回采样。
  206. - SamplingStrategy.EDGE_WEIGHT,以边缘权重为概率进行采样。
  207. **返回:**
  208. numpy.ndarray,包含邻居的数组。
  209. **示例:**
  210. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  211. >>> neighbors = graph_dataset.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
  212. ... neighbor_types=[2, 1])
  213. **异常:**
  214. - **TypeError**:参数`node_list`的类型不为列表或numpy.ndarray。
  215. - **TypeError**:参数`neighbor_nums`的类型不为列表或numpy.ndarray。
  216. - **TypeError**:参数`neighbor_types`的类型不为列表或numpy.ndarray。
  217. graph_info()
  218. 获取图的元信息,包括节点数、节点类型、节点特征信息、边数、边类型、边特征信息。
  219. **返回:**
  220. dict,图的元信息。键为node_num、node_type、node_feature_type、edge_num、edge_type、和edge_feature_type。
  221. random_walk(target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1)
  222. 在节点中的随机游走。
  223. **参数:**
  224. - **target_nodes** (list[int]):随机游走中的起始节点列表。
  225. - **meta_path** (list[int]):每个步长的节点类型。
  226. - **step_home_param** (float, 可选):返回node2vec算法中的超参(默认为1.0)。
  227. - **step_away_param** (float, 可选):node2vec算法中的in和out超参(默认为1.0)。
  228. - **default_node** (int, 可选):如果找不到更多邻居,则为默认节点(默认值为-1,表示不给定节点)。
  229. **返回:**
  230. numpy.ndarray,包含节点的数组。
  231. **示例:**
  232. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  233. >>> walks = graph_dataset.random_walk(target_nodes=nodes, meta_path=[2, 1, 2])
  234. **异常:**
  235. - **TypeError**:参数`target_nodes`的类型不为列表或numpy.ndarray。
  236. - **TypeError**:参数`meta_path`的类型不为列表或numpy.ndarray。