
gat.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Graph Attention Networks."""
import mindspore.nn as nn
from mindspore._checkparam import Validator

from aggregator import AttentionAggregator


class GAT(nn.Cell):
    """
    Graph Attention Network.

    Args:
        ftr_dims (int): Initial feature dimensions.
        num_class (int): Number of classes to identify.
        num_nodes (int): Number of nodes in the graph.
        hidden_units (list[int]): Number of hidden units at each layer.
        num_heads (list[int]): Number of attention heads at each layer.
        attn_drop (float): Dropout ratio of the attention coefficients,
            default 0.0.
        ftr_drop (float): Dropout ratio of the features, default 0.0.
        activation (Cell): Activation function for the output layer, default
            nn.ELU().
        residual (bool): Whether to use residual connections between
            intermediate layers, default False.

    Examples:
        >>> ft_sizes = 1433
        >>> num_class = 7
        >>> num_nodes = 2708
        >>> hid_units = [8]
        >>> n_heads = [8, 1]
        >>> activation = nn.ELU()
        >>> residual = False
        >>> input_data = Tensor(
                np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> net = GAT(ft_sizes,
                      num_class,
                      num_nodes,
                      hidden_units=hid_units,
                      num_heads=n_heads,
                      attn_drop=0.6,
                      ftr_drop=0.6,
                      activation=activation,
                      residual=residual)
        >>> output = net(input_data, biases)
    """

    def __init__(self,
                 ftr_dims,
                 num_class,
                 num_nodes,
                 hidden_units,
                 num_heads,
                 attn_drop=0.0,
                 ftr_drop=0.0,
                 activation=nn.ELU(),
                 residual=False):
        super(GAT, self).__init__()
        self.ftr_dims = Validator.check_positive_int(ftr_dims)
        self.num_class = Validator.check_positive_int(num_class)
        self.num_nodes = Validator.check_positive_int(num_nodes)
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.attn_drop = attn_drop
        self.ftr_drop = ftr_drop
        self.activation = activation
        self.residual = Validator.check_bool(residual)
        self.layers = []
        # first layer: raw features -> hidden_units[0] per head
        self.layers.append(AttentionAggregator(
            self.ftr_dims,
            self.hidden_units[0],
            self.num_heads[0],
            self.ftr_drop,
            self.attn_drop,
            self.activation,
            residual=False))
        # intermediate layers: input is the previous layer's heads concatenated
        for i in range(1, len(self.hidden_units)):
            self.layers.append(AttentionAggregator(
                self.hidden_units[i-1]*self.num_heads[i-1],
                self.hidden_units[i],
                self.num_heads[i],
                self.ftr_drop,
                self.attn_drop,
                self.activation,
                residual=self.residual))
        # output layer: sums its heads; no activation here
        self.layers.append(AttentionAggregator(
            self.hidden_units[-1]*self.num_heads[-2],
            self.num_class,
            self.num_heads[-1],
            self.ftr_drop,
            self.attn_drop,
            activation=None,
            residual=False,
            output_transform='sum'))
        self.layers = nn.layer.CellList(self.layers)

    def construct(self, input_data, bias_mat):
        # pass through all attention layers, then average the summed output heads
        for cell in self.layers:
            input_data = cell(input_data, bias_mat)
        return input_data/self.num_heads[-1]
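
For reference, a minimal end-to-end sketch of driving this class, based on the docstring example above. It assumes `aggregator.py` (providing `AttentionAggregator`) sits next to this file on the Python path; the sizes (1433 features, 7 classes, 2708 nodes) are the Cora-style defaults from the docstring, not values this file enforces.

import numpy as np
from mindspore import Tensor
import mindspore.nn as nn

from gat import GAT  # this file; aggregator.py must be importable as well

# random Cora-sized inputs, mirroring the class docstring
input_data = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
biases = Tensor(np.random.rand(1, 2708, 2708).astype(np.float32))

net = GAT(ftr_dims=1433,
          num_class=7,
          num_nodes=2708,
          hidden_units=[8],
          num_heads=[8, 1],
          attn_drop=0.6,
          ftr_drop=0.6,
          activation=nn.ELU(),
          residual=False)

output = net(input_data, biases)
# per the docstring example, this should be one score per class for each node,
# i.e. shape (1, 2708, 7)
print(output.shape)

Note the design choice visible in __init__/construct: hidden layers concatenate their attention heads (the next layer's input width is hidden_units[i-1]*num_heads[i-1]), while the output layer sums its heads (output_transform='sum') and construct divides by num_heads[-1], which amounts to averaging the output heads.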