
aggregator.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Aggregator."""
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore._checkparam import Validator
from mindspore._extends import cell_attr_register
from mindspore.common.initializer import initializer
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class GNNFeatureTransform(nn.Cell):
    r"""
    The GNN feature transform layer for input.

    Applies a linear transformation to the input feature. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{inputs} * \text{kernel} + \text{bias},

    where :math:`\text{kernel}` is a weight matrix with the same data type as the inputs
    created by the layer, and :math:`\text{bias}` is a bias vector with the same data type
    as the inputs created by the layer (only if has_bias is True).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.

    Raises:
        ValueError: If weight_init or bias_init shape is incorrect.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(*B, N, C)`,
          where :math:`*B` represents the batch size which can be multidimensional, and :math:`N` and :math:`C` are
          the sizes of the last two dimensions.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(*B, N, M)`.

    Examples:
        >>> net = GNNFeatureTransform(3, 4)
        >>> input_x = Tensor(np.random.randint(0, 255, [2, 3, 3]), mindspore.float32)
        >>> output = net(input_x)   # output shape: (2, 3, 4)
    """

    @cell_attr_register
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super(GNNFeatureTransform, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = Validator.check_bool(has_bias)

        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape error")
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape error")
            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()

    def construct(self, x):
        tensor_shape = F.shape(x)
        input_feature = F.reshape(x, (tensor_shape[0] * tensor_shape[1], tensor_shape[2]))
        output = self.matmul(input_feature, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        output = F.reshape(output, (tensor_shape[0], tensor_shape[1], self.out_channels))
        return output

    def extend_repr(self):
        str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias={}'.format(self.bias)
        return str_info


class _BaseAggregator(nn.Cell):
    """
    Base Aggregator of GNN.

    Args:
        feature_in_dim (int): Node or edge input feature dim.
        feature_out_dim (int): Node or edge output feature dim.
        use_fc (bool): Specifies whether to apply a linear transformation before the message is aggregated.
            Default: True.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        dropout_ratio (float): The keep rate of the dropout layer, greater than 0 and less than or equal to 1.
            Default: None.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.

    Examples:
        >>> class MyAggregator(_BaseAggregator):
        >>>     def __init__(self, feature_in_dim, feature_out_dim):
        >>>         super(MyAggregator, self).__init__(feature_in_dim, feature_out_dim)
        >>>         self.reduce_sum = P.ReduceSum()
        >>>
        >>>     def construct(self, x):
        >>>         return self.reduce_sum(x, 1)
    """

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 use_fc=True,
                 weight_init="normal",
                 bias_init="zeros",
                 has_bias=True,
                 dropout_ratio=None,
                 activation=None):
        super(_BaseAggregator, self).__init__()
        self.in_dim = feature_in_dim
        self.out_dim = feature_out_dim
        self.use_fc = use_fc
        if self.use_fc:
            self.weight_init = weight_init
            self.bias_init = bias_init
            self.has_bias = has_bias
            self.fc = GNNFeatureTransform(self.in_dim,
                                          self.out_dim,
                                          weight_init=self.weight_init,
                                          bias_init=self.bias_init,
                                          has_bias=self.has_bias)

        self.dropout_ratio = dropout_ratio
        if self.dropout_ratio is not None:
            self.dropout = nn.Dropout(keep_prob=self.dropout_ratio)
        self.dropout_flag = self.dropout_ratio is not None

        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

    def construct(self, **kwargs):
        """Must be overridden by all subclasses."""
        raise NotImplementedError


class MeanAggregator(_BaseAggregator):
    """
    Mean Aggregator of GNN.

    Args:
        feature_in_dim (int): Node or edge input feature dim.
        feature_out_dim (int): Node or edge output feature dim.
        use_fc (bool): Specifies whether to apply a linear transformation before the message is aggregated.
            Default: True.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        dropout_ratio (float): The keep rate of the dropout layer, greater than 0 and less than or equal to 1.
            Default: None.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.

    Examples:
        >>> net = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
        >>> input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtype=np.float32))
        >>> output = net(input_data)
    """

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 use_fc=True,
                 weight_init="normal",
                 bias_init="zeros",
                 has_bias=True,
                 dropout_ratio=None,
                 activation=None):
        super(MeanAggregator, self).__init__(
            feature_in_dim,
            feature_out_dim,
            use_fc,
            weight_init,
            bias_init,
            has_bias,
            dropout_ratio,
            activation)
        self.reduce_mean = P.ReduceMean(keep_dims=False)

    def construct(self, input_feature):
        if self.use_fc:
            input_feature = self.fc(input_feature)
        if self.dropout_flag:
            input_feature = self.dropout(input_feature)
        if self.activation_flag:
            input_feature = self.activation(input_feature)
        output_feature = self.reduce_mean(input_feature, 1)
        return output_feature


class AttentionHead(nn.Cell):
    """
    Attention Head for Graph Attention Networks.

    Args:
        in_channel (int): The number of input channels, input feature dim.
        out_channel (int): The number of output channels, output feature dim.
        in_drop_ratio (float): Input feature dropout ratio, default 0.0.
        coef_drop_ratio (float): Coefficient dropout ratio, default 0.0.
        residual (bool): Whether to use residual connection, default False.
        coef_activation (Cell): The attention coefficient activation function,
            default nn.LeakyReLU().
        activation (Cell): The output activation function, default nn.ELU().

    Inputs:
        - **input_feature** (Tensor) - Tensor of shape: (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape: (batch_size, num_nodes, num_nodes).

    Examples:
        >>> head = AttentionHead(1433,
        ...                      8,
        ...                      in_drop_ratio=0.6,
        ...                      coef_drop_ratio=0.6,
        ...                      residual=False)
        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> bias_mat = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> output = head(input_data, bias_mat)
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 in_drop_ratio=0.0,
                 coef_drop_ratio=0.0,
                 residual=False,
                 coef_activation=nn.LeakyReLU(),
                 activation=nn.ELU()):
        super(AttentionHead, self).__init__()
        self.in_channel = Validator.check_positive_int(in_channel)
        self.out_channel = Validator.check_positive_int(out_channel)
        self.in_drop_ratio = in_drop_ratio
        self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
        self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
        self.feature_transform = GNNFeatureTransform(
            in_channels=self.in_channel,
            out_channels=self.out_channel,
            has_bias=False)

        self.f_1_transform = GNNFeatureTransform(
            in_channels=self.out_channel,
            out_channels=1)
        self.f_2_transform = GNNFeatureTransform(
            in_channels=self.out_channel,
            out_channels=1)
        self.softmax = nn.Softmax()

        self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio)
        self.batch_matmul = P.BatchMatMul()
        self.bias_add = P.BiasAdd()
        self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
        self.residual = Validator.check_bool(residual)
        if self.residual:
            if in_channel != out_channel:
                self.residual_transform_flag = True
                self.residual_transform = GNNFeatureTransform(
                    in_channels=self.in_channel,
                    out_channels=self.out_channel)
            else:
                self.residual_transform_flag = False
                self.residual_transform = None
        self.coef_activation = coef_activation
        self.activation = activation

    def construct(self, input_feature, bias_mat):
        input_feature = self.in_drop(input_feature)

        feature = self.feature_transform(input_feature)
        # self-attention coefficients following the original author's implementation
        f_1 = self.f_1_transform(feature)
        f_2 = self.f_2_transform(feature)
        logits = f_1 + P.Transpose()(f_2, (0, 2, 1))
        logits = self.coef_activation(logits) + bias_mat
        coefs = self.softmax(logits)

        coefs = self.coef_drop(coefs)
        feature = self.in_drop_2(feature)

        ret = self.batch_matmul(coefs, feature)
        ret = P.Squeeze(0)(ret)
        ret = self.bias_add(ret, self.bias)
        ret = P.ExpandDims()(ret, 0)
        # residual connection
        if self.residual:
            if self.residual_transform_flag:
                res = self.residual_transform(input_feature)
                ret = ret + res
            else:
                ret = ret + input_feature
        # activation
        if self.activation is not None:
            ret = self.activation(ret)
        return ret


class AttentionAggregator(nn.Cell):
    """
    Attention aggregator for Graph Attention Networks; it stacks multiple attention
    heads and can be regarded as one GAT layer.

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        num_heads (int): Number of attention heads for this layer, default 1.
        in_drop (float): Input feature dropout ratio, default 0.0.
        coef_drop (float): Coefficient dropout ratio, default 0.0.
        activation (Cell): The output activation function, default nn.ELU().
        residual (bool): Whether to use residual connection, default False.
        output_transform (str['concat', 'sum']): Output transform for the layer,
            default 'concat'.

    Inputs:
        - **input_feature** (Tensor) - Tensor of shape: (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape: (batch_size, num_nodes, num_nodes).

    Examples:
        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> net = AttentionAggregator(1433,
        ...                           8,
        ...                           8)
        >>> net(input_data, biases)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_heads=1,
                 in_drop=0.0,
                 coef_drop=0.0,
                 activation=nn.ELU(),
                 residual=False,
                 output_transform='concat'):
        super(AttentionAggregator, self).__init__()
        self.num_heads = num_heads
        self.attns = []
        for _ in range(num_heads):
            self.attns.append(AttentionHead(in_channels,
                                            out_channels,
                                            in_drop_ratio=in_drop,
                                            coef_drop_ratio=coef_drop,
                                            activation=activation,
                                            residual=residual))
        self.attns = nn.layer.CellList(self.attns)
        if output_transform == 'concat':
            self.out_trans = P.Concat(-1)
        elif output_transform == 'sum':
            self.out_trans = P.AddN()
        else:
            raise ValueError("output_transform must be either 'concat' or 'sum'")

    def construct(self, input_data, bias_mat):
        res = ()
        for i in range(self.num_heads):
            res += (self.attns[i](input_data, bias_mat),)
        return self.out_trans(res)
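

# ---------------------------------------------------------------------------
# Minimal usage sketch: it mirrors the shapes from the docstring examples above
# and assumes numpy is installed and a MindSpore backend is configured; the
# variable names below are illustrative only.
if __name__ == "__main__":
    import numpy as np

    # Mean aggregation over the neighbor axis: (32, 3, 32) -> (32, 64).
    mean_agg = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
    node_feat = Tensor(np.random.rand(32, 3, 32).astype(np.float32))
    print(mean_agg(node_feat).shape)

    # One multi-head GAT layer: 8 heads of output size 8, concatenated to 64.
    gat_layer = AttentionAggregator(1433, 8, num_heads=8)
    input_data = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
    biases = Tensor(np.random.rand(1, 2708, 2708).astype(np.float32))
    print(gat_layer(input_data, biases).shape)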