
ncf.py 12 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Neural Collaborative Filtering Model"""
from mindspore import nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore._checkparam import Validator as validator
from mindspore.nn.layer.activation import get_activation
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from src.lr_schedule import dynamic_lr


class DenseLayer(nn.Cell):
    """
    Dense layer that computes in float16 and casts the result back to float32.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None):
        super(DenseLayer, self).__init__()
        self.in_channels = validator.check_positive_int(in_channels)
        self.out_channels = validator.check_positive_int(out_channels)
        self.has_bias = validator.check_bool(has_bias)
        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
                    weight_init.shape()[1] != in_channels:
                raise ValueError("weight_init shape error")
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]))
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
                    raise ValueError("bias_init shape error")
            self.bias = Parameter(initializer(bias_init, [out_channels]))
        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()
        self.cast = P.Cast()
        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

    def construct(self, x):
        """
        Dense layer construct method.
        """
        x = self.cast(x, mstype.float16)
        weight = self.cast(self.weight, mstype.float16)
        output = self.matmul(x, weight)
        if self.has_bias:
            # Cast the bias only when it exists; the original code cast
            # self.bias unconditionally, which fails when has_bias=False.
            bias = self.cast(self.bias, mstype.float16)
            output = self.bias_add(output, bias)
        if self.activation_flag:
            output = self.activation(output)
        output = self.cast(output, mstype.float32)
        return output

    def extend_repr(self):
        """A pretty print for Dense layer."""
        str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias={}'.format(self.bias)
        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)
        return str_info
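
# Usage sketch for DenseLayer (illustrative, not part of the original file):
# inputs and weights are cast to float16 for the matmul and the result is
# returned as float32, e.g.
#
#   dense = DenseLayer(in_channels=64, out_channels=32, activation="relu")
#   out = dense(Tensor(np.ones((256, 64)), mstype.float32))  # -> (256, 32), float32
#
# where `np` is an assumed `import numpy as np`; the float16 matmul generally
# needs an Ascend or GPU backend.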


class NCFModel(nn.Cell):
    """
    Class for the Neural Collaborative Filtering model from the paper
    "Neural Collaborative Filtering".
    """
    def __init__(self,
                 num_users,
                 num_items,
                 num_factors,
                 model_layers,
                 mf_regularization,
                 mlp_reg_layers,
                 mf_dim):
        super(NCFModel, self).__init__()
        self.data_path = ""
        self.model_path = ""
        self.num_users = num_users
        self.num_items = num_items
        self.num_factors = num_factors
        self.model_layers = model_layers
        self.mf_regularization = mf_regularization
        self.mlp_reg_layers = mlp_reg_layers
        self.mf_dim = mf_dim

        self.num_layers = len(self.model_layers)  # Number of layers in the MLP
        if self.model_layers[0] % 2 != 0:
            raise ValueError("The first layer size should be a multiple of 2!")

        # Initializer for embedding layers
        self.embedding_initializer = "normal"

        # Each embedding stores the GMF part (num_factors) and the MLP part
        # (model_layers[0] // 2) side by side; construct slices them apart.
        self.embedding_user = nn.Embedding(
            self.num_users,
            self.num_factors + self.model_layers[0] // 2,
            embedding_table=self.embedding_initializer
        )
        self.embedding_item = nn.Embedding(
            self.num_items,
            self.num_factors + self.model_layers[0] // 2,
            embedding_table=self.embedding_initializer
        )

        # MLP dense layers
        self.mlp_dense1 = DenseLayer(in_channels=self.model_layers[0],
                                     out_channels=self.model_layers[1],
                                     activation="relu")
        self.mlp_dense2 = DenseLayer(in_channels=self.model_layers[1],
                                     out_channels=self.model_layers[2],
                                     activation="relu")

        # Logit dense layer
        self.logits_dense = DenseLayer(in_channels=self.model_layers[1],
                                       out_channels=1,
                                       weight_init="normal",
                                       activation=None)

        # ops definition
        self.mul = P.Mul()
        self.squeeze = P.Squeeze(axis=1)
        self.concat = P.Concat(axis=1)

    def construct(self, user_input, item_input):
        """
        NCF construct method.
        """
        # Embedding layers, shared by the GMF and MLP branches
        embedding_user = self.embedding_user(user_input)  # (256, 1) -> (256, 1, 16 + 32)
        embedding_item = self.embedding_item(item_input)  # (256, 1) -> (256, 1, 16 + 32)

        # GMF part
        mf_user_latent = self.squeeze(embedding_user)[:, :self.num_factors]  # (256, 1, 16 + 32) -> (256, 16)
        mf_item_latent = self.squeeze(embedding_item)[:, :self.num_factors]  # (256, 1, 16 + 32) -> (256, 16)

        # MLP part
        mlp_user_latent = self.squeeze(embedding_user)[:, self.mf_dim:]  # (256, 1, 16 + 32) -> (256, 32)
        mlp_item_latent = self.squeeze(embedding_item)[:, self.mf_dim:]  # (256, 1, 16 + 32) -> (256, 32)

        # Element-wise multiply
        mf_vector = self.mul(mf_user_latent, mf_item_latent)  # (256, 16) * (256, 16) -> (256, 16)

        # Concatenation of the two latent features
        mlp_vector = self.concat((mlp_user_latent, mlp_item_latent))  # (256, 32), (256, 32) -> (256, 64)

        # MLP dense layers
        mlp_vector = self.mlp_dense1(mlp_vector)  # (256, 64) -> (256, 32)
        mlp_vector = self.mlp_dense2(mlp_vector)  # (256, 32) -> (256, 16)

        # Concatenate GMF and MLP parts
        predict_vector = self.concat((mf_vector, mlp_vector))  # (256, 16), (256, 16) -> (256, 32)

        # Final prediction layer
        logits = self.logits_dense(predict_vector)  # (256, 32) -> (256, 1)
        return logits
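
# Instantiation sketch (illustrative, not part of the original file); the
# user/item counts are hypothetical, while num_factors, model_layers and
# mf_dim mirror the shape comments in construct (the slicing there assumes
# mf_dim == num_factors):
#
#   net = NCFModel(num_users=6040, num_items=3706, num_factors=16,
#                  model_layers=[64, 32, 16], mf_regularization=0.0,
#                  mlp_reg_layers=[0.0, 0.0, 0.0], mf_dim=16)
#   logits = net(user_ids, item_ids)  # (batch, 1) int32 ids -> (batch, 1) logits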


class NetWithLossClass(nn.Cell):
    """
    NetWithLossClass definition
    """
    def __init__(self, network):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
        self.network = network
        self.reducesum = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.squeeze = P.Squeeze(axis=1)
        self.zeroslike = P.ZerosLike()
        self.concat = P.Concat(axis=1)
        self.reciprocal = P.Reciprocal()

    def construct(self, batch_users, batch_items, labels, valid_pt_mask):
        predict = self.network(batch_users, batch_items)
        # Pair each single logit with a zero logit for the negative class so
        # the sparse softmax cross-entropy can be applied.
        predict = self.concat((self.zeroslike(predict), predict))
        labels = self.squeeze(labels)
        loss = self.loss(predict, labels)
        # Mask out padded examples, then average over the valid ones.
        loss = self.mul(loss, self.squeeze(valid_pt_mask))
        mean_loss = self.mul(self.reducesum(loss), self.reciprocal(self.reducesum(valid_pt_mask)))
        return mean_loss
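
# The masked average above is equivalent to
#   mean_loss = sum(loss * valid_pt_mask) / sum(valid_pt_mask)
# e.g. per-example losses [0.5, 0.3, 0.9] with mask [1, 1, 0] give
# mean_loss = 0.8 / 2 = 0.4, so padded examples do not dilute the loss.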


class TrainStepWrap(nn.Cell):
    """
    TrainStepWrap definition
    """
    def __init__(self, network, total_steps=1, sens=16384.0):
        super(TrainStepWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_train()
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        lr = dynamic_lr(0.01, total_steps, 5000)
        self.optimizer = nn.Adam(self.weights,
                                 learning_rate=lr,
                                 beta1=0.9,
                                 beta2=0.999,
                                 eps=1e-8,
                                 loss_scale=sens)
        self.hyper_map = C.HyperMap()
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)

    def construct(self, batch_users, batch_items, labels, valid_pt_mask):
        weights = self.weights
        loss = self.network(batch_users, batch_items, labels, valid_pt_mask)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(batch_users, batch_items, labels, valid_pt_mask, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
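
# Loss-scaling note: construct seeds the backward pass with a tensor filled
# with `sens` (default 16384.0), which multiplies every gradient by that
# factor to keep float16 gradients from underflowing; nn.Adam receives the
# same value as loss_scale and divides it back out, so the parameter update
# matches unscaled float32 training.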


class PredictWithSigmoid(nn.Cell):
    """
    Predict definition
    """
    def __init__(self, network, k, num_eval_neg):
        super(PredictWithSigmoid, self).__init__()
        self.network = network
        self.topk = P.TopK(sorted=True)
        self.squeeze = P.Squeeze()
        self.k = k
        self.num_eval_neg = num_eval_neg
        self.gather = P.GatherV2()
        self.reshape = P.Reshape()
        self.reducesum = P.ReduceSum(keep_dims=False)
        self.notequal = P.NotEqual()

    def construct(self, batch_users, batch_items, duplicated_masks):
        predicts = self.network(batch_users, batch_items)  # (bs, 1)
        # Group each user's candidates (1 positive + num_eval_neg negatives).
        predicts = self.reshape(predicts, (-1, self.num_eval_neg + 1))  # (num_user, 100)
        batch_items = self.reshape(batch_items, (-1, self.num_eval_neg + 1))  # (num_user, 100)
        duplicated_masks = self.reshape(duplicated_masks, (-1, self.num_eval_neg + 1))  # (num_user, 100)
        masks_sum = self.reducesum(duplicated_masks, 1)
        metric_weights = self.notequal(masks_sum, self.num_eval_neg)  # (num_user,)
        _, indices = self.topk(predicts, self.k)  # (num_user, k)
        return indices, batch_items, metric_weights
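

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative, not part of the original file).
# It assumes a working MindSpore install, the src.lr_schedule module imported
# above, and a backend that supports float16 matmul (Ascend/GPU); the user
# and item counts below are hypothetical, ml-1m-like values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    batch_size = 256
    net = NCFModel(num_users=6040,
                   num_items=3706,
                   num_factors=16,
                   model_layers=[64, 32, 16],
                   mf_regularization=0.0,
                   mlp_reg_layers=[0.0, 0.0, 0.0],
                   mf_dim=16)
    train_net = TrainStepWrap(NetWithLossClass(net), total_steps=1)

    users = Tensor(np.random.randint(0, 6040, (batch_size, 1)).astype(np.int32))
    items = Tensor(np.random.randint(0, 3706, (batch_size, 1)).astype(np.int32))
    labels = Tensor(np.random.randint(0, 2, (batch_size, 1)).astype(np.int32))
    mask = Tensor(np.ones((batch_size, 1), np.float32))

    loss = train_net(users, items, labels, mask)  # one training step
    print(loss)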