You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_bgcf.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Architecture"""
  16. import os
  17. import numpy as np
  18. import pytest
  19. import mindspore.nn as nn
  20. from mindspore import Parameter, Tensor, context
  21. from mindspore.ops import operations as P
  22. from mindspore.common import dtype as mstype
  23. from mindspore.common.initializer import initializer
  24. from mindspore.train.serialization import export
# Module-level side effect: runs once when this module is imported and switches
# the MindSpore context to PYNATIVE_MODE before any network below is built.
context.set_context(mode=context.PYNATIVE_MODE)
  26. class MeanConv(nn.Cell):
  27. def __init__(self,
  28. feature_in_dim,
  29. feature_out_dim,
  30. activation,
  31. dropout=0.2):
  32. super(MeanConv, self).__init__()
  33. self.out_weight = Parameter(
  34. initializer("XavierUniform", [feature_in_dim * 2, feature_out_dim], dtype=mstype.float32))
  35. if activation == "tanh":
  36. self.act = P.Tanh()
  37. elif activation == "relu":
  38. self.act = P.ReLU()
  39. else:
  40. raise ValueError("activation should be tanh or relu")
  41. self.cast = P.Cast()
  42. self.matmul = P.MatMul()
  43. self.concat = P.Concat(axis=1)
  44. self.reduce_mean = P.ReduceMean(keep_dims=False)
  45. self.dropout = nn.Dropout(keep_prob=1 - dropout)
  46. def construct(self, self_feature, neigh_feature):
  47. neigh_matrix = self.reduce_mean(neigh_feature, 1)
  48. neigh_matrix = self.dropout(neigh_matrix)
  49. output = self.concat((self_feature, neigh_matrix))
  50. output = self.act(self.matmul(output, self.out_weight))
  51. return output
  52. class AttenConv(nn.Cell):
  53. def __init__(self,
  54. feature_in_dim,
  55. feature_out_dim,
  56. dropout=0.2):
  57. super(AttenConv, self).__init__()
  58. self.out_weight = Parameter(
  59. initializer("XavierUniform", [feature_in_dim * 2, feature_out_dim], dtype=mstype.float32))
  60. self.cast = P.Cast()
  61. self.squeeze = P.Squeeze(1)
  62. self.concat = P.Concat(axis=1)
  63. self.expanddims = P.ExpandDims()
  64. self.softmax = P.Softmax(axis=-1)
  65. self.matmul = P.MatMul()
  66. self.matmul_3 = P.BatchMatMul()
  67. self.matmul_t = P.BatchMatMul(transpose_b=True)
  68. self.dropout = nn.Dropout(keep_prob=1 - dropout)
  69. def construct(self, self_feature, neigh_feature):
  70. query = self.expanddims(self_feature, 1)
  71. neigh_matrix = self.dropout(neigh_feature)
  72. score = self.matmul_t(query, neigh_matrix)
  73. score = self.softmax(score)
  74. atten_agg = self.matmul_3(score, neigh_matrix)
  75. atten_agg = self.squeeze(atten_agg)
  76. output = self.matmul(self.concat((atten_agg, self_feature)), self.out_weight)
  77. return output
  78. class BGCF(nn.Cell):
  79. def __init__(self,
  80. dataset_argv,
  81. architect_argv,
  82. activation,
  83. neigh_drop_rate,
  84. num_user,
  85. num_item,
  86. input_dim):
  87. super(BGCF, self).__init__()
  88. self.user_embed = Parameter(initializer("XavierUniform", [num_user, input_dim], dtype=mstype.float32))
  89. self.item_embed = Parameter(initializer("XavierUniform", [num_item, input_dim], dtype=mstype.float32))
  90. self.cast = P.Cast()
  91. self.tanh = P.Tanh()
  92. self.shape = P.Shape()
  93. self.split = P.Split(0, 2)
  94. self.gather = P.Gather()
  95. self.reshape = P.Reshape()
  96. self.concat_0 = P.Concat(0)
  97. self.concat_1 = P.Concat(1)
  98. (self.input_dim, self.num_user, self.num_item) = dataset_argv
  99. self.layer_dim = architect_argv
  100. self.gnew_agg_mean = MeanConv(self.input_dim, self.layer_dim,
  101. activation=activation, dropout=neigh_drop_rate[1])
  102. self.gnew_agg_mean.to_float(mstype.float16)
  103. self.gnew_agg_user = AttenConv(self.input_dim, self.layer_dim, dropout=neigh_drop_rate[2])
  104. self.gnew_agg_user.to_float(mstype.float16)
  105. self.gnew_agg_item = AttenConv(self.input_dim, self.layer_dim, dropout=neigh_drop_rate[2])
  106. self.gnew_agg_item.to_float(mstype.float16)
  107. self.user_feature_dim = self.input_dim
  108. self.item_feature_dim = self.input_dim
  109. self.final_weight = Parameter(
  110. initializer("XavierUniform", [self.input_dim * 3, self.input_dim * 3], dtype=mstype.float32))
  111. self.raw_agg_funcs_user = MeanConv(self.input_dim, self.layer_dim,
  112. activation=activation, dropout=neigh_drop_rate[0])
  113. self.raw_agg_funcs_user.to_float(mstype.float16)
  114. self.raw_agg_funcs_item = MeanConv(self.input_dim, self.layer_dim,
  115. activation=activation, dropout=neigh_drop_rate[0])
  116. self.raw_agg_funcs_item.to_float(mstype.float16)
  117. def construct(self,
  118. u_id,
  119. pos_item_id,
  120. neg_item_id,
  121. pos_users,
  122. pos_items,
  123. u_group_nodes,
  124. u_neighs,
  125. u_gnew_neighs,
  126. i_group_nodes,
  127. i_neighs,
  128. i_gnew_neighs,
  129. neg_group_nodes,
  130. neg_neighs,
  131. neg_gnew_neighs,
  132. neg_item_num):
  133. all_user_embed = self.gather(self.user_embed, self.concat_0((u_id, pos_users)), 0)
  134. u_self_matrix_at_layers = self.gather(self.user_embed, u_group_nodes, 0)
  135. u_neigh_matrix_at_layers = self.gather(self.item_embed, u_neighs, 0)
  136. u_output_mean = self.raw_agg_funcs_user(u_self_matrix_at_layers, u_neigh_matrix_at_layers)
  137. u_gnew_neighs_matrix = self.gather(self.item_embed, u_gnew_neighs, 0)
  138. u_output_from_gnew_mean = self.gnew_agg_mean(u_self_matrix_at_layers, u_gnew_neighs_matrix)
  139. u_output_from_gnew_att = self.gnew_agg_user(u_self_matrix_at_layers,
  140. self.concat_1((u_neigh_matrix_at_layers, u_gnew_neighs_matrix)))
  141. u_output = self.concat_1((u_output_mean, u_output_from_gnew_mean, u_output_from_gnew_att))
  142. all_user_rep = self.tanh(u_output)
  143. all_pos_item_embed = self.gather(self.item_embed, self.concat_0((pos_item_id, pos_items)), 0)
  144. i_self_matrix_at_layers = self.gather(self.item_embed, i_group_nodes, 0)
  145. i_neigh_matrix_at_layers = self.gather(self.user_embed, i_neighs, 0)
  146. i_output_mean = self.raw_agg_funcs_item(i_self_matrix_at_layers, i_neigh_matrix_at_layers)
  147. i_gnew_neighs_matrix = self.gather(self.user_embed, i_gnew_neighs, 0)
  148. i_output_from_gnew_mean = self.gnew_agg_mean(i_self_matrix_at_layers, i_gnew_neighs_matrix)
  149. i_output_from_gnew_att = self.gnew_agg_item(i_self_matrix_at_layers,
  150. self.concat_1((i_neigh_matrix_at_layers, i_gnew_neighs_matrix)))
  151. i_output = self.concat_1((i_output_mean, i_output_from_gnew_mean, i_output_from_gnew_att))
  152. all_pos_item_rep = self.tanh(i_output)
  153. neg_item_embed = self.gather(self.item_embed, neg_item_id, 0)
  154. neg_self_matrix_at_layers = self.gather(self.item_embed, neg_group_nodes, 0)
  155. neg_neigh_matrix_at_layers = self.gather(self.user_embed, neg_neighs, 0)
  156. neg_output_mean = self.raw_agg_funcs_item(neg_self_matrix_at_layers, neg_neigh_matrix_at_layers)
  157. neg_gnew_neighs_matrix = self.gather(self.user_embed, neg_gnew_neighs, 0)
  158. neg_output_from_gnew_mean = self.gnew_agg_mean(neg_self_matrix_at_layers, neg_gnew_neighs_matrix)
  159. neg_output_from_gnew_att = self.gnew_agg_item(neg_self_matrix_at_layers,
  160. self.concat_1(
  161. (neg_neigh_matrix_at_layers, neg_gnew_neighs_matrix)))
  162. neg_output = self.concat_1((neg_output_mean, neg_output_from_gnew_mean, neg_output_from_gnew_att))
  163. neg_output = self.tanh(neg_output)
  164. neg_output_shape = self.shape(neg_output)
  165. neg_item_rep = self.reshape(neg_output,
  166. (self.shape(neg_item_embed)[0], neg_item_num, neg_output_shape[-1]))
  167. return all_user_embed, all_user_rep, all_pos_item_embed, all_pos_item_rep, neg_item_embed, neg_item_rep
  168. class ForwardBGCF(nn.Cell):
  169. def __init__(self,
  170. network):
  171. super(ForwardBGCF, self).__init__()
  172. self.network = network
  173. def construct(self, users, items, neg_items, u_neighs, u_gnew_neighs, i_neighs, i_gnew_neighs):
  174. _, user_rep, _, item_rep, _, _, = self.network(users, items, neg_items, users, items, users,
  175. u_neighs, u_gnew_neighs, items, i_neighs, i_gnew_neighs,
  176. items, i_neighs, i_gnew_neighs, 1)
  177. return user_rep, item_rep
  178. @pytest.mark.level0
  179. @pytest.mark.platform_x86_ascend_training
  180. @pytest.mark.platform_arm_ascend_training
  181. @pytest.mark.env_onecard
  182. def test_export_bgcf():
  183. num_user, num_item = 7068, 3570
  184. network = BGCF([64, num_user, num_item], 64, "tanh",
  185. [0.0, 0.0, 0.0], num_user, num_item, 64)
  186. forward_net = ForwardBGCF(network)
  187. users = Tensor(np.zeros([num_user,]).astype(np.int32))
  188. items = Tensor(np.zeros([num_item,]).astype(np.int32))
  189. neg_items = Tensor(np.zeros([num_item, 1]).astype(np.int32))
  190. u_test_neighs = Tensor(np.zeros([num_user, 40]).astype(np.int32))
  191. u_test_gnew_neighs = Tensor(np.zeros([num_user, 20]).astype(np.int32))
  192. i_test_neighs = Tensor(np.zeros([num_item, 40]).astype(np.int32))
  193. i_test_gnew_neighs = Tensor(np.zeros([num_item, 20]).astype(np.int32))
  194. input_data = [users, items, neg_items, u_test_neighs, u_test_gnew_neighs, i_test_neighs, i_test_gnew_neighs]
  195. file_name = "bgcf"
  196. export(forward_net, *input_data, file_name=file_name, file_format="MINDIR")
  197. mindir_file = file_name + ".mindir"
  198. assert os.path.exists(mindir_file)
  199. os.remove(mindir_file)
  200. export(forward_net, *input_data, file_name=file_name, file_format="AIR")
  201. air_file = file_name + ".air"
  202. assert os.path.exists(air_file)
  203. os.remove(air_file)