# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Aggregator."""
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore._checkparam import Validator
from mindspore._extends import cell_attr_register
from mindspore.common.initializer import initializer
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class GNNFeatureTransform(nn.Cell):
    r"""
    The GNN feature transform layer for input.

    Applies a linear transformation to the input feature. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{inputs} * \text{kernel} + \text{bias},

    where :math:`\text{kernel}` is a weight matrix with the same data type as the inputs created by the layer,
    and :math:`\text{bias}` is a bias vector with the same data type as the inputs created by the layer
    (only if has_bias is True).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.

    Raises:
        ValueError: If weight_init or bias_init shape is incorrect.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(*B, N, C)`,
          where :math:`*B` represents the batch size which can be multidimensional, and :math:`N` and :math:`C` are
          the size of the last two dimensions.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(*B, N, M)`.

    Examples:
        >>> net = GNNFeatureTransform(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net(input)
        [[ 2.5246444   2.2738023   0.5711005  -3.9399147 ]
         [ 1.0739875   4.0155234   0.94188046 -5.459526  ]]
    """

    @cell_attr_register
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super(GNNFeatureTransform, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = Validator.check_bool(has_bias)

        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape error")

        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape error")

            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()

    def construct(self, x):
        tensor_shape = F.shape(x)
        input_feature = F.reshape(x, (tensor_shape[0] * tensor_shape[1], tensor_shape[2]))
        output = self.matmul(input_feature, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        output = F.reshape(output, (tensor_shape[0], tensor_shape[1], self.out_channels))
        return output

    def extend_repr(self):
        s = 'in_channels={}, out_channels={}'.format(self.in_channels, self.out_channels)
        if self.has_bias:
            s += ', has_bias={}'.format(self.has_bias)
        return s


class _BaseAggregator(nn.Cell):
    """
    Base Aggregator of GNN.

    Args:
        feature_in_dim (int): Node or edge input feature dim.
        feature_out_dim (int): Node or edge output feature dim.
        use_fc (bool): Specifies whether to apply a linear transformation before the message is aggregated.
            Default: True.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        dropout_ratio (float): The keep rate of the dropout layer, greater than 0 and less than or equal to 1.
            Default: None.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.

    Examples:
        >>> class MyAggregator(_BaseAggregator):
        >>>     def __init__(self, feature_in_dim, feature_out_dim):
        >>>         super(MyAggregator, self).__init__(feature_in_dim, feature_out_dim)
        >>>         self.reduce_sum = P.ReduceSum()
        >>>
        >>>     def construct(self, x):
        >>>         return self.reduce_sum(x, 1)
    """

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 use_fc=True,
                 weight_init="normal",
                 bias_init="zeros",
                 has_bias=True,
                 dropout_ratio=None,
                 activation=None):
        super(_BaseAggregator, self).__init__()
        self.in_dim = feature_in_dim
        self.out_dim = feature_out_dim
        self.use_fc = use_fc
        if self.use_fc:
            self.weight_init = weight_init
            self.bias_init = bias_init
            self.has_bias = has_bias
            self.fc = GNNFeatureTransform(self.in_dim,
                                          self.out_dim,
                                          weight_init=self.weight_init,
                                          bias_init=self.bias_init,
                                          has_bias=self.has_bias)
        self.dropout_ratio = dropout_ratio
        if self.dropout_ratio is not None:
            self.dropout = nn.Dropout(keep_prob=self.dropout_ratio)
        self.dropout_flag = self.dropout_ratio is not None
        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

    def construct(self, **kwargs):
        """Must be overridden by all subclasses."""
        raise NotImplementedError


class MeanAggregator(_BaseAggregator):
    """
    Mean Aggregator of GNN.

    Args:
        feature_in_dim (int): Node or edge input feature dim.
        feature_out_dim (int): Node or edge output feature dim.
        use_fc (bool): Specifies whether to apply a linear transformation before the message is aggregated.
            Default: True.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        dropout_ratio (float): The keep rate of the dropout layer, greater than 0 and less than or equal to 1.
            Default: None.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.

    Examples:
        >>> net = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
        >>> input_data = Tensor(np.array(np.random.rand(32, 3, 32), dtype=np.float32))
        >>> output = net(input_data)
    """

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 use_fc=True,
                 weight_init="normal",
                 bias_init="zeros",
                 has_bias=True,
                 dropout_ratio=None,
                 activation=None):
        super(MeanAggregator, self).__init__(
            feature_in_dim,
            feature_out_dim,
            use_fc,
            weight_init,
            bias_init,
            has_bias,
            dropout_ratio,
            activation)
        self.reduce_mean = P.ReduceMean(keep_dims=False)

    def construct(self, input_feature):
        if self.use_fc:
            input_feature = self.fc(input_feature)
        if self.dropout_flag:
            input_feature = self.dropout(input_feature)
        if self.activation_flag:
            input_feature = self.activation(input_feature)
        output_feature = self.reduce_mean(input_feature, 1)
        return output_feature


class AttentionHead(nn.Cell):
    """
    Attention Head for Graph Attention Networks.

    Args:
        in_channel (int): The number of input channels, input feature dim.
        out_channel (int): The number of output channels, output feature dim.
        in_drop_ratio (float): Input feature dropout ratio, default 0.0.
        coef_drop_ratio (float): Coefficient dropout ratio, default 0.0.
        residual (bool): Whether to use residual connection, default False.
        coef_activation (Cell): The attention coefficient activation function,
            default nn.LeakyReLU().
        activation (Cell): The output activation function, default nn.ELU().

    Inputs:
        - **input_feature** (Tensor) - Tensor of shape (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape (batch_size, num_nodes, num_nodes).

    Examples:
        >>> head = AttentionHead(1433,
                                 8,
                                 in_drop_ratio=0.6,
                                 coef_drop_ratio=0.6,
                                 residual=False)
        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> bias_mat = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> output = head(input_data, bias_mat)
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 in_drop_ratio=0.0,
                 coef_drop_ratio=0.0,
                 residual=False,
                 coef_activation=nn.LeakyReLU(),
                 activation=nn.ELU()):
        super(AttentionHead, self).__init__()
        self.in_channel = Validator.check_positive_int(in_channel)
        self.out_channel = Validator.check_positive_int(out_channel)
        self.in_drop_ratio = in_drop_ratio
        self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
        self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
        self.feature_transform = GNNFeatureTransform(
            in_channels=self.in_channel,
            out_channels=self.out_channel,
            has_bias=False)

        self.f_1_transform = GNNFeatureTransform(
            in_channels=self.out_channel,
            out_channels=1)
        self.f_2_transform = GNNFeatureTransform(
            in_channels=self.out_channel,
            out_channels=1)
        self.softmax = nn.Softmax()

        self.coef_drop = nn.Dropout(keep_prob=1 - coef_drop_ratio)
        self.batch_matmul = P.BatchMatMul()
        self.bias_add = P.BiasAdd()
        self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
        self.residual = Validator.check_bool(residual)
        # identity residual by default; a transform is only needed when the channel dims differ
        self.residual_transform_flag = False
        if self.residual:
            if in_channel != out_channel:
                self.residual_transform_flag = True
                self.residual_transform = GNNFeatureTransform(
                    in_channels=self.in_channel,
                    out_channels=self.out_channel)
            else:
                self.residual_transform = None
        self.coef_activation = coef_activation
        self.activation = activation

    def construct(self, input_feature, bias_mat):
        input_feature = self.in_drop(input_feature)

        feature = self.feature_transform(input_feature)
        # self attention following the author
        f_1 = self.f_1_transform(feature)
        f_2 = self.f_2_transform(feature)
        logits = f_1 + P.Transpose()(f_2, (0, 2, 1))
        logits = self.coef_activation(logits) + bias_mat
        coefs = self.softmax(logits)

        coefs = self.coef_drop(coefs)
        feature = self.in_drop_2(feature)

        ret = self.batch_matmul(coefs, feature)
        ret = P.Squeeze(0)(ret)
        ret = self.bias_add(ret, self.bias)
        ret = P.ExpandDims()(ret, 0)
        # residual connection
        if self.residual:
            if self.residual_transform_flag:
                res = self.residual_transform(input_feature)
                ret = ret + res
            else:
                ret = ret + input_feature
        # activation
        if self.activation is not None:
            ret = self.activation(ret)
        return ret


class AttentionAggregator(nn.Cell):
    """
    Attention Aggregator for Graph Attention Networks; it stacks multiple attention heads and
    can be regarded as one GAT layer.

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        num_heads (int): Number of attention heads for this layer, default 1.
        in_drop (float): Input feature dropout ratio, default 0.0.
        coef_drop (float): Coefficient dropout ratio, default 0.0.
        activation (Cell): The output activation function, default nn.ELU().
        residual (bool): Whether to use residual connection, default False.
        output_transform (str['concat', 'sum']): Output transform for a layer,
            default 'concat'.

    Inputs:
        - **input_feature** (Tensor) - Tensor of shape (batch_size, num_nodes, feature_dim).
        - **bias_mat** (Tensor) - Tensor of shape (batch_size, num_nodes, num_nodes).

    Examples:
        >>> input_data = Tensor(np.array(np.random.rand(1, 2708, 1433), dtype=np.float32))
        >>> biases = Tensor(np.array(np.random.rand(1, 2708, 2708), dtype=np.float32))
        >>> net = AttentionAggregator(1433,
                                      8,
                                      8)
        >>> net(input_data, biases)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_heads=1,
                 in_drop=0.0,
                 coef_drop=0.0,
                 activation=nn.ELU(),
                 residual=False,
                 output_transform='concat'):
        super(AttentionAggregator, self).__init__()
        self.num_heads = num_heads
        self.attns = []
        for _ in range(num_heads):
            self.attns.append(AttentionHead(in_channels,
                                            out_channels,
                                            in_drop_ratio=in_drop,
                                            coef_drop_ratio=coef_drop,
                                            activation=activation,
                                            residual=residual))
        self.attns = nn.layer.CellList(self.attns)
        if output_transform == 'concat':
            self.out_trans = P.Concat(-1)
        elif output_transform == 'sum':
            self.out_trans = P.AddN()
        else:
            raise ValueError("output_transform must be either 'concat' or 'sum'")

    def construct(self, input_data, bias_mat):
        res = ()
        for i in range(self.num_heads):
            res += (self.attns[i](input_data, bias_mat),)
        return self.out_trans(res)
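

# ----------------------------------------------------------------------------
# Minimal usage sketch (an illustration added for this write-up, not part of
# the original module): it exercises MeanAggregator and AttentionAggregator
# with random data, reusing the shapes from the docstring examples above.
# Assumes MindSpore and NumPy are available; shapes and values are arbitrary.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    # MeanAggregator: (batch, num_neighbors, feature_in_dim) -> (batch, feature_out_dim)
    mean_agg = MeanAggregator(32, 64, activation="relu", dropout_ratio=0.5)
    neighbor_feats = Tensor(np.random.rand(32, 3, 32).astype(np.float32))
    print(mean_agg(neighbor_feats).shape)  # expected: (32, 64)

    # AttentionAggregator: one GAT layer with 8 heads whose outputs are concatenated.
    gat_layer = AttentionAggregator(1433, 8, num_heads=8)
    node_feats = Tensor(np.random.rand(1, 2708, 1433).astype(np.float32))
    bias_mat = Tensor(np.zeros((1, 2708, 2708), dtype=np.float32))
    print(gat_layer(node_feats, bias_mat).shape)  # expected: (1, 2708, 64)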