You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.GraphData.rst 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. mindspore.dataset.GraphData
  2. ===========================
  3. .. py:class:: mindspore.dataset.GraphData(dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, num_client=1, auto_shutdown=True)
  4. 从共享文件和数据库中读取用于GNN训练的图数据集。
  5. **参数:**
  6. - **dataset_file** (str) - 数据集文件路径。
  7. - **num_parallel_workers** (int, 可选) - 读取数据的工作线程数(默认为None)。
  8. - **working_mode** (str, 可选) - 设置工作模式,目前支持'local'/'client'/'server'(默认为'local')。
  9. - **local**:用于非分布式训练场景。
  10. - **client**:用于分布式训练场景。客户端不加载数据,而是从服务器获取数据。
  11. - **server**:用于分布式训练场景。服务器加载数据并可供客户端使用。
  12. - **hostname** (str, 可选) - 图数据集服务器的主机名。该参数仅在工作模式设置为 'client' 或 'server' 时有效(默认为'127.0.0.1')。
  13. - **port** (int, 可选) - 图数据服务器的端口,取值范围为1024-65535。此参数仅当工作模式设置为 'client' 或 'server' (默认为50051)时有效。
  14. - **num_client** (int, 可选) - 期望连接到服务器的最大客户端数。服务器将根据该参数分配资源。该参数仅在工作模式设置为 'server' 时有效(默认为1)。
  15. - **auto_shutdown** (bool, 可选) - 当工作模式设置为 'server' 时有效。当连接的客户端数量达到 `num_client` ,且没有客户端正在连接时,服务器将自动退出(默认为True)。
  16. **样例:**
  17. >>> graph_dataset_dir = "/path/to/graph_dataset_file"
  18. >>> graph_dataset = ds.GraphData(dataset_file=graph_dataset_dir, num_parallel_workers=2)
  19. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  20. >>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[1])
  21. .. py:method:: get_all_edges(edge_type)
  22. 获取图的所有边。
  23. **参数:**
  24. - **edge_type** (int) - 指定边的类型。
  25. **返回:**
  26. numpy.ndarray,包含边的数组。
  27. **样例:**
  28. >>> edges = graph_dataset.get_all_edges(edge_type=0)
  29. **异常:**
  30. **TypeError**:参数 `edge_type` 的类型不为整型。
  31. .. py:method:: get_all_neighbors(node_list, neighbor_type, output_format=<OutputFormat.NORMAL: 0。
  32. 获取 `node_list` 所有节点的邻居,以 `neighbor_type` 类型返回。格式的定义参见以下示例:1表示两个节点之间连接,0表示不连接。
  33. .. list-table:: 邻接矩阵
  34. :widths: 20 20 20 20 20
  35. :header-rows: 1
  36. * -
  37. - 0
  38. - 1
  39. - 2
  40. - 3
  41. * - 0
  42. - 0
  43. - 1
  44. - 0
  45. - 0
  46. * - 1
  47. - 0
  48. - 0
  49. - 1
  50. - 0
  51. * - 2
  52. - 1
  53. - 0
  54. - 0
  55. - 1
  56. * - 3
  57. - 1
  58. - 0
  59. - 0
  60. - 0
  61. .. list-table:: 普通格式
  62. :widths: 20 20 20 20 20
  63. :header-rows: 1
  64. * - src
  65. - 0
  66. - 1
  67. - 2
  68. - 3
  69. * - dst_0
  70. - 1
  71. - 2
  72. - 0
  73. - 1
  74. * - dst_1
  75. - -1
  76. - -1
  77. - 3
  78. - -1
  79. .. list-table:: COO格式
  80. :widths: 20 20 20 20 20 20
  81. :header-rows: 1
  82. * - src
  83. - 0
  84. - 1
  85. - 2
  86. - 2
  87. - 3
  88. * - dst
  89. - 1
  90. - 2
  91. - 0
  92. - 3
  93. - 1
  94. .. list-table:: CSR格式
  95. :widths: 40 20 20 20 20 20
  96. :header-rows: 1
  97. * - offsetTable
  98. - 0
  99. - 1
  100. - 2
  101. - 4
  102. -
  103. * - dstTable
  104. - 1
  105. - 2
  106. - 0
  107. - 3
  108. - 1
  109. **参数:**
  110. - **node_list** (Union[list, numpy.ndarray]) - 给定的节点列表。
  111. - **neighbor_type** (int) - 指定邻居节点的类型。
  112. - **output_format** (OutputFormat, 可选) - 输出存储格式(默认为mindspore.dataset.engine.OutputFormat.NORMAL)取值范围:[OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR]。
  113. **返回:**
  114. 对于普通格式或COO格式,将返回numpy.ndarray类型的数组表示邻居节点。如果指定了CSR格式,将返回两个numpy.ndarray数组,第一个表示偏移表,第二个表示邻居节点。
  115. **样例:**
  116. >>> from mindspore.dataset.engine import OutputFormat
  117. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  118. >>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2)
  119. >>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
  120. ... output_format=OutputFormat.COO)
  121. >>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
  122. ... output_format=OutputFormat.CSR)
  123. **异常:**
  124. - **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。
  125. - **TypeError** - 参数 `neighbor_type` 的类型不为整型。
  126. .. py:method:: get_all_nodes(node_type)
  127. 获取图中的所有节点。
  128. **参数:**
  129. - **node_type** (int) - 指定节点的类型。
  130. **返回:**
  131. numpy.ndarray,包含节点的数组。
  132. **样例:**
  133. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  134. **异常:**
  135. **TypeError**:参数 `node_type` 的类型不为整型。
  136. .. py:method:: get_edges_from_nodes(node_list)
  137. 从节点获取边。
  138. **参数:**
  139. - **node_list** (Union[list[tuple], numpy.ndarray]) - 含一个或多个图节点ID对的列表。
  140. **返回:**
  141. numpy.ndarray,含一个或多个边ID的数组。
  142. **示例:**
  143. >>> edges = graph_dataset.get_edges_from_nodes(node_list=[(101, 201), (103, 207)])
  144. **异常:**
  145. **TypeError**:参数 `edge_list` 的类型不为列表或numpy.ndarray。
  146. .. py:method:: get_edge_feature(edge_list, feature_types)
  147. 获取 `edge_list` 列表中边的特征,以 `feature_types` 类型返回。
  148. **参数:**
  149. - **edge_list** (Union[list, numpy.ndarray]) - 包含边的列表。
  150. - **feature_types** (Union[list, numpy.ndarray]) - 包含给定特征类型的列表。
  151. **返回:**
  152. numpy.ndarray,包含特征的数组。
  153. **样例:**
  154. >>> edges = graph_dataset.get_all_edges(edge_type=0)
  155. >>> features = graph_dataset.get_edge_feature(edge_list=edges, feature_types=[1])
  156. **异常:**
  157. - **TypeError** - 参数 `edge_list` 的类型不为列表或numpy.ndarray。
  158. - **TypeError** - 参数 `feature_types` 的类型不为列表或numpy.ndarray。
  159. .. py:method:: get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type)
  160. 获取 `node_list` 列表中节所有点的负样本邻居,以 `neg_neighbor_type` 类型返回。
  161. **参数:**
  162. - **node_list** (Union[list, numpy.ndarray]) - 包含节点的列表。
  163. - **neg_neighbor_num** (int) - 采样的邻居数量。
  164. - **neg_neighbor_type** (int) - 指定负样本邻居的类型。
  165. **返回:**
  166. numpy.ndarray,包含邻居的数组。
  167. **样例:**
  168. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  169. >>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
  170. ... neg_neighbor_type=2)
  171. **异常:**
  172. - **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。
  173. - **TypeError** - 参数 `neg_neighbor_num` 的类型不为整型。
  174. - **TypeError** - 参数 `neg_neighbor_type` 的类型不为整型。
  175. .. py:method:: get_nodes_from_edges(edge_list)
  176. 从图中的边获取节点。
  177. **参数:**
  178. - **edge_list** (Union[list, numpy.ndarray]) - 包含边的列表。
  179. **返回:**
  180. numpy.ndarray,包含节点的数组。
  181. **异常:**
  182. **TypeError:** 参数 `edge_list` 不为列表或ndarray。
  183. .. py:method:: get_node_feature(node_list, feature_types)
  184. 获取 `node_list` 中节点的特征,以 `feature_types` 类型返回。
  185. **参数:**
  186. - **node_list** (Union[list, numpy.ndarray]) - 包含节点的列表。
  187. - **feature_types** (Union[list, numpy.ndarray]) - 指定特征的类型。
  188. **返回:**
  189. numpy.ndarray,包含特征的数组。
  190. **示例:**
  191. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  192. >>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[2, 3])
  193. **异常:**
  194. - **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。
  195. - **TypeError** - 参数 `feature_types` 的类型不为列表或numpy.ndarray。
  196. .. py:method:: get_sampled_neighbors(node_list, neighbor_nums, neighbor_types, strategy=<SamplingStrategy.RANDOM: 0>)
  197. 获取已采样邻居信息。此API支持多跳邻居采样。即将上一次采样结果作为下一跳采样的输入,最多允许6跳。采样结果平铺成列表,格式为[input node, 1-hop sampling result, 2-hop samling result ...]
  198. **参数:**
  199. - **node_list** (Union[list, numpy.ndarray]) - 包含节点的列表。
  200. - **neighbor_nums** (Union[list, numpy.ndarray]) - 每跳采样的邻居数。
  201. - **neighbor_types** (Union[list, numpy.ndarray]) - 每跳采样的邻居类型。
  202. - **strategy** (SamplingStrategy, 可选) - 采样策略(默认为mindspore.dataset.engine.SamplingStrategy.RANDOM)。取值范围:[SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT]。
  203. - **SamplingStrategy.RANDOM**:随机抽样,带放回采样。
  204. - **SamplingStrategy.EDGE_WEIGHT**:以边缘权重为概率进行采样。
  205. **返回:**
  206. numpy.ndarray,包含邻居的数组。
  207. *样例:**
  208. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  209. >>> neighbors = graph_dataset.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
  210. ... neighbor_types=[2, 1])
  211. **异常:**
  212. - **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。
  213. - **TypeError** - 参数 `neighbor_nums` 的类型不为列表或numpy.ndarray。
  214. - **TypeError** - 参数 `neighbor_types` 的类型不为列表或numpy.ndarray。
  215. .. py:method:: graph_info()
  216. 获取图的元信息,包括节点数、节点类型、节点特征信息、边数、边类型、边特征信息。
  217. **返回:**
  218. dict,图的元信息。键为 `node_num` 、 `node_type` 、 `node_feature_type` 、 `edge_num` 、 `edge_type` 、和 `edge_feature_type` 。
  219. .. py:method:: random_walk(target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1)
  220. 在节点中的随机游走。
  221. **参数:**
  222. - **target_nodes** (list[int]) - 随机游走中的起始节点列表。
  223. - **meta_path** (list[int]) - 每个步长的节点类型。
  224. - **step_home_param** (float, 可选) - 返回node2vec算法中的超参(默认为1.0)。
  225. - **step_away_param** (float, 可选) - node2vec算法中的in和out超参(默认为1.0)。
  226. - **default_node** (int, 可选) - 如果找不到更多邻居,则为默认节点(默认值为-1,表示不给定节点)。
  227. **返回:**
  228. numpy.ndarray,包含节点的数组。
  229. **示例:**
  230. >>> nodes = graph_dataset.get_all_nodes(node_type=1)
  231. >>> walks = graph_dataset.random_walk(target_nodes=nodes, meta_path=[2, 1, 2])
  232. **异常:**
  233. - **TypeError** - 参数 `target_nodes` 的类型不为列表或numpy.ndarray。
  234. - **TypeError** - 参数 `meta_path` 的类型不为列表或numpy.ndarray。