# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """wide and deep model""" | |||||
| from mindspore import nn | |||||
| from mindspore import Tensor, Parameter, ParameterTuple | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import operations as P | |||||
| # from mindspore.nn import Dropout | |||||
| from mindspore.nn.optim import Adam, FTRL | |||||
| # from mindspore.nn.metrics import Metric | |||||
| from mindspore.common.initializer import Uniform, initializer | |||||
| # from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||||
| import numpy as np | |||||
| np_type = np.float32 | |||||
| ms_type = mstype.float32 | |||||
def init_method(method, shape, name, max_val=1.0):
    """
    Create a Parameter of the given shape, initialized with the given method.

    `max_val` is only used by the 'uniform' method.
    """
    if method in ['uniform']:
        params = Parameter(initializer(
            Uniform(max_val), shape, ms_type), name=name)
    elif method == "one":
        params = Parameter(initializer("ones", shape, ms_type), name=name)
    elif method == 'zero':
        params = Parameter(initializer("zeros", shape, ms_type), name=name)
    elif method == "normal":
        params = Parameter(Tensor(np.random.normal(
            loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name)
    else:
        raise ValueError("Unsupported init method: {}".format(method))
    return params

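# A minimal usage sketch (illustrative only, not part of the original file):
#
#     w = init_method("uniform", [16, 8], name="w", max_val=0.01)
#     b = init_method("zero", [8], name="b")
#
# `w` is a Parameter with entries drawn uniformly from [-0.01, 0.01], and
# `b` is an all-zero Parameter.
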
def init_var_dict(init_args, in_vars):
    """
    Initialize a dict of Parameters from (name, shape, method) triples.
    """
    var_map = {}
    _, _max_val = init_args
    for item in in_vars:
        key, shape, method = item
        if key not in var_map:
            if method in ['random', 'uniform']:
                var_map[key] = Parameter(initializer(
                    Uniform(_max_val), shape, ms_type), name=key)
            elif method == "one":
                var_map[key] = Parameter(initializer(
                    "ones", shape, ms_type), name=key)
            elif method == "zero":
                var_map[key] = Parameter(initializer(
                    "zeros", shape, ms_type), name=key)
            elif method == 'normal':
                var_map[key] = Parameter(Tensor(np.random.normal(
                    loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=key)
            else:
                raise ValueError("Unsupported init method: {}".format(method))
    return var_map

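# A small sketch of how init_var_dict is called (hypothetical values; only
# the second element of init_args, the uniform half-width, is used):
#
#     var_map = init_var_dict((-0.01, 0.01),
#                             [("Wide_w", [1000, 1], "uniform"),
#                              ("Wide_b", [1], "zero")])
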
class DenseLayer(nn.Cell):
    """
    Dense layer for the deep part of the Wide&Deep model.

    Applies activation, matmul and bias_add (dropout is left disabled here).

    Args:
        input_dim (int): input feature dimension.
        output_dim (int): output feature dimension.
        weight_bias_init (tuple): init methods for (weight, bias).
        act_str (str): activation name, one of 'relu', 'sigmoid', 'tanh'.
        keep_prob (float): dropout keep probability. Default: 0.7.
        scale_coef (float): loss-scale coefficient multiplied into the input
            and divided out of the matmul result. Default: 1.0.
        convert_dtype (bool): if True, run the matmul in float16. Default: True.
    """

    def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
                 keep_prob=0.7, scale_coef=1.0, convert_dtype=True):
        super(DenseLayer, self).__init__()
        weight_init, bias_init = weight_bias_init
        self.weight = init_method(
            weight_init, [input_dim, output_dim], name="weight")
        self.bias = init_method(bias_init, [output_dim], name="bias")
        self.act_func = self._init_activation(act_str)
        self.matmul = P.MatMul(transpose_b=False)
        self.bias_add = P.BiasAdd()
        self.cast = P.Cast()
        # self.dropout = Dropout(keep_prob=keep_prob)
        self.mul = P.Mul()
        self.realDiv = P.RealDiv()
        self.scale_coef = scale_coef
        self.convert_dtype = convert_dtype

    def _init_activation(self, act_str):
        act_str = act_str.lower()
        if act_str == "relu":
            act_func = P.ReLU()
        elif act_str == "sigmoid":
            act_func = P.Sigmoid()
        elif act_str == "tanh":
            act_func = P.Tanh()
        else:
            raise ValueError("Unsupported activation: {}".format(act_str))
        return act_func

    def construct(self, x):
        x = self.act_func(x)
        # if self.training:
        #     x = self.dropout(x)
        x = self.mul(x, self.scale_coef)
        if self.convert_dtype:
            # Run the matmul in float16 for speed, then cast back.
            x = self.cast(x, mstype.float16)
            weight = self.cast(self.weight, mstype.float16)
            wx = self.matmul(x, weight)
            wx = self.cast(wx, mstype.float32)
        else:
            wx = self.matmul(x, self.weight)
        # Undo the scale applied before the matmul.
        wx = self.realDiv(wx, self.scale_coef)
        output = self.bias_add(wx, self.bias)
        return output

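# A minimal sketch of a DenseLayer call (shapes are illustrative):
#
#     layer = DenseLayer(64, 32, ("uniform", "zero"), "relu",
#                        convert_dtype=False)
#     y = layer(Tensor(np.ones((16, 64), dtype=np_type)))  # -> (16, 32)
#
# Note that the activation is applied to the *input* of the layer, so the
# last layer of the deep tower still produces an unbounded logit.
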
class WideDeepModel(nn.Cell):
    """
    Wide&Deep model, from the paper "Wide & Deep Learning for Recommender Systems".

    Args:
        config (Class): the default config of Wide&Deep.
    """

    def __init__(self, config):
        super(WideDeepModel, self).__init__()
        self.batch_size = config.batch_size
        self.field_size = config.field_size
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.deep_layer_args = config.deep_layer_args
        self.deep_layer_dims_list, self.deep_layer_act = self.deep_layer_args
        self.init_args = config.init_args
        self.weight_init, self.bias_init = config.weight_bias_init
        self.weight_bias_init = config.weight_bias_init
        self.emb_init = config.emb_init
        self.drop_out = config.dropout_flag
        self.keep_prob = config.keep_prob
        self.deep_input_dims = self.field_size * self.emb_dim
        self.layer_dims = self.deep_layer_dims_list + [1]
        self.all_dim_list = [self.deep_input_dims] + self.layer_dims
        init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init),
                     ('V_l2', [self.vocab_size, self.emb_dim], self.emb_init),
                     ('Wide_b', [1], self.emb_init)]
        var_map = init_var_dict(self.init_args, init_acts)
        self.wide_w = var_map["Wide_w"]
        self.wide_b = var_map["Wide_b"]
        self.embedding_table = var_map["V_l2"]
        self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
                                        self.all_dim_list[1],
                                        self.weight_bias_init,
                                        self.deep_layer_act, convert_dtype=True)
        self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
                                        self.all_dim_list[2],
                                        self.weight_bias_init,
                                        self.deep_layer_act, convert_dtype=True)
        self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
                                        self.all_dim_list[3],
                                        self.weight_bias_init,
                                        self.deep_layer_act, convert_dtype=True)
        self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
                                        self.all_dim_list[4],
                                        self.weight_bias_init,
                                        self.deep_layer_act, convert_dtype=True)
        self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
                                        self.all_dim_list[5],
                                        self.weight_bias_init,
                                        self.deep_layer_act, convert_dtype=True)
        self.gather_v2 = P.GatherV2()
        self.mul = P.Mul()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.reshape = P.Reshape()
        self.square = P.Square()
        self.shape = P.Shape()
        self.tile = P.Tile()
        self.concat = P.Concat(axis=1)
        self.cast = P.Cast()

    def construct(self, id_hldr, wt_hldr):
        """
        Args:
            id_hldr: batch of feature ids, shape (batch_size, field_size).
            wt_hldr: batch of feature weights, shape (batch_size, field_size).
        """
        mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
        # Wide part: per-feature scalar weights, summed over fields.
        wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0)
        wx = self.mul(wide_id_weight, mask)
        wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
        # Deep part: embeddings, flattened and fed through the dense tower.
        deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0)
        vx = self.mul(deep_id_embs, mask)
        deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
        deep_in = self.dense_layer_1(deep_in)
        deep_in = self.dense_layer_2(deep_in)
        deep_in = self.dense_layer_3(deep_in)
        deep_in = self.dense_layer_4(deep_in)
        deep_out = self.dense_layer_5(deep_in)
        out = wide_out + deep_out
        return out, self.embedding_table

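# A minimal construction sketch, assuming a config object carrying the
# fields read above (batch_size, field_size, vocab_size, emb_dim,
# deep_layer_args, init_args, weight_bias_init, emb_init, dropout_flag,
# keep_prob):
#
#     net = WideDeepModel(config)
#     out, emb_table = net(batch_ids, batch_wts)  # out: (batch_size, 1) logits
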
class NetWithLossClass(nn.Cell):
    """
    Provide the Wide&Deep training losses through the network.

    Args:
        network (Cell): the training network.
        config (Class): Wide&Deep config.
    """

    def __init__(self, network, config):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.network = network
        self.l2_coef = config.l2_coef
        self.loss = P.SigmoidCrossEntropyWithLogits()
        self.square = P.Square()
        self.reduceMean_false = P.ReduceMean(keep_dims=False)
        self.reduceSum_false = P.ReduceSum(keep_dims=False)

    def construct(self, batch_ids, batch_wts, label):
        predict, embedding_table = self.network(batch_ids, batch_wts)
        log_loss = self.loss(predict, label)
        wide_loss = self.reduceMean_false(log_loss)
        # Only the deep loss carries an L2 penalty on the embedding table.
        l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
        deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
        return wide_loss, deep_loss

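# In formula form, with y the label, p the logit and V the embedding table:
#
#     wide_loss = mean(sigmoid_cross_entropy(p, y))
#     deep_loss = wide_loss + l2_coef * ||V||^2 / 2
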
class IthOutputCell(nn.Cell):
    """Wrap a network so that only its `output_index`-th output is returned."""

    def __init__(self, network, output_index):
        super(IthOutputCell, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, x1, x2, x3):
        predict = self.network(x1, x2, x3)[self.output_index]
        return predict

class TrainStepWrap(nn.Cell):
    """
    Encapsulation class for Wide&Deep network training.

    Appends an FTRL optimizer (wide part) and an Adam optimizer (deep part)
    to the training network; after that, the construct function can be
    called to create the backward graph.

    Args:
        network (Cell): the training network. Note that the loss function
            should already be wrapped in (see NetWithLossClass).
        sens (Number): the loss-scale adjustment parameter. Default: 1000.0.
    """

    def __init__(self, network, sens=1000.0):
        super(TrainStepWrap, self).__init__()
        self.network = network
        self.network.set_train()
        self.trainable_params = network.trainable_params()
        weights_w = []
        weights_d = []
        # Wide parameters are trained with FTRL, all others with Adam.
        for params in self.trainable_params:
            if 'wide' in params.name:
                weights_w.append(params)
            else:
                weights_d.append(params)
        self.weights_w = ParameterTuple(weights_w)
        self.weights_d = ParameterTuple(weights_d)
        self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
                                l1=1e-8, l2=1e-8, initial_accum=1.0)
        self.optimizer_d = Adam(
            self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
        self.hyper_map = C.HyperMap()
        self.grad_w = C.GradOperation('grad_w', get_by_list=True,
                                      sens_param=True)
        self.grad_d = C.GradOperation('grad_d', get_by_list=True,
                                      sens_param=True)
        self.sens = sens
        self.loss_net_w = IthOutputCell(network, output_index=0)
        self.loss_net_d = IthOutputCell(network, output_index=1)

    def construct(self, batch_ids, batch_wts, label):
        weights_w = self.weights_w
        weights_d = self.weights_d
        loss_w, loss_d = self.network(batch_ids, batch_wts, label)
        # Fill the sensitivity (initial gradient) with the loss scale.
        sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
        sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
        grads_w = self.grad_w(self.loss_net_w, weights_w)(batch_ids, batch_wts,
                                                          label, sens_w)
        grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts,
                                                          label, sens_d)
        return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d,
                                                                     self.optimizer_d(grads_d))

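# A typical training wiring (a sketch under the usual MindSpore setup, not
# tied to any specific script in this repo):
#
#     net = WideDeepModel(config)
#     loss_net = NetWithLossClass(net, config)
#     train_net = TrainStepWrap(loss_net)
#     loss_w, loss_d = train_net(batch_ids, batch_wts, label)  # one step
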
class PredictWithSigmoid(nn.Cell):
    """Wrap the network for evaluation: return logits, probabilities and labels."""

    def __init__(self, network):
        super(PredictWithSigmoid, self).__init__()
        self.network = network
        self.sigmoid = P.Sigmoid()

    def construct(self, batch_ids, batch_wts, labels):
        # WideDeepModel returns (logits, embedding_table); drop the table.
        logits, _ = self.network(batch_ids, batch_wts)
        pred_probs = self.sigmoid(logits)
        return logits, pred_probs, labels

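# Evaluation sketch: wrap the bare model (not the loss network), since
# PredictWithSigmoid expects the (logits, embedding_table) outputs:
#
#     eval_net = PredictWithSigmoid(net)
#     logits, probs, labels = eval_net(batch_ids, batch_wts, label)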