@@ -24,9 +24,12 @@ from mindspore.ops import operations as P
 from mindspore.nn import Dropout
 from mindspore.nn.optim import Adam
 from mindspore.nn.metrics import Metric
-from mindspore import nn, ParameterTuple, Parameter
-from mindspore.common.initializer import Uniform, initializer, Normal
+from mindspore import nn, Tensor, ParameterTuple, Parameter
+from mindspore.common.initializer import Uniform, initializer
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from mindspore.context import ParallelMode
+from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from .callback import EvalCallBack, LossCallBack
@@ -60,7 +63,7 @@ class AUCMetric(Metric):
         return auc


-def init_method(method, shape, name, max_val=0.01):
+def init_method(method, shape, name, max_val=1.0):
     """
     The method of init parameters.
@@ -73,18 +76,18 @@ def init_method(method, shape, name, max_val=0.01):
     Returns:
         Parameter.
     """
-    if method in ['random', 'uniform']:
+    if method in ['uniform']:
         params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
     elif method == "one":
         params = Parameter(initializer("ones", shape, ms_type), name=name)
     elif method == 'zero':
         params = Parameter(initializer("zeros", shape, ms_type), name=name)
     elif method == "normal":
-        params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
+        params = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name)
     return params


-def init_var_dict(init_args, values):
+def init_var_dict(init_args, var_list):
     """
     Init parameter.
@@ -96,17 +99,19 @@ def init_var_dict(init_args, values):
         dict, a dict ot Parameter.
     """
     var_map = {}
-    _, _max_val = init_args
-    for key, shape, init_flag in values:
+    _, max_val = init_args
+    for i, _ in enumerate(var_list):
+        key, shape, method = var_list[i]
         if key not in var_map.keys():
-            if init_flag in ['random', 'uniform']:
-                var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key)
-            elif init_flag == "one":
+            if method in ['random', 'uniform']:
+                var_map[key] = Parameter(initializer(Uniform(max_val), shape, ms_type), name=key)
+            elif method == "one":
                 var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key)
-            elif init_flag == "zero":
+            elif method == "zero":
                 var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key)
-            elif init_flag == 'normal':
-                var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key)
+            elif method == 'normal':
+                var_map[key] = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).
+                                                astype(dtype=np_type)), name=key)
     return var_map
@@ -122,7 +127,9 @@ class DenseLayer(nn.Cell):
         keep_prob (float): Dropout Layer keep_prob_rate;
         scale_coef (float): input scale coefficient;
     """
-    def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0):
+    def __init__(self, input_dim, output_dim, weight_bias_init, act_str, scale_coef=1.0, convert_dtype=True,
+                 use_act=True):
         super(DenseLayer, self).__init__()
         weight_init, bias_init = weight_bias_init
         self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
@@ -131,12 +138,15 @@ class DenseLayer(nn.Cell):
         self.matmul = P.MatMul(transpose_b=False)
         self.bias_add = P.BiasAdd()
         self.cast = P.Cast()
-        self.dropout = Dropout(keep_prob=keep_prob)
+        self.dropout = Dropout(keep_prob=1.0)
         self.mul = P.Mul()
         self.realDiv = P.RealDiv()
         self.scale_coef = scale_coef
+        self.convert_dtype = convert_dtype
+        self.use_act = use_act

     def _init_activation(self, act_str):
+        """Init activation function"""
         act_str = act_str.lower()
         if act_str == "relu":
             act_func = P.ReLU()
@@ -147,17 +157,23 @@ class DenseLayer(nn.Cell):
         return act_func

     def construct(self, x):
-        x = self.act_func(x)
-        if self.training:
-            x = self.dropout(x)
-        x = self.mul(x, self.scale_coef)
-        x = self.cast(x, mstype.float16)
-        weight = self.cast(self.weight, mstype.float16)
-        wx = self.matmul(x, weight)
-        wx = self.cast(wx, mstype.float32)
-        wx = self.realDiv(wx, self.scale_coef)
-        output = self.bias_add(wx, self.bias)
-        return output
+        """Construct function"""
+        x = self.dropout(x)
+        if self.convert_dtype:
+            x = self.cast(x, mstype.float16)
+            weight = self.cast(self.weight, mstype.float16)
+            bias = self.cast(self.bias, mstype.float16)
+            wx = self.matmul(x, weight)
+            wx = self.bias_add(wx, bias)
+            if self.use_act:
+                wx = self.act_func(wx)
+            wx = self.cast(wx, mstype.float32)
+        else:
+            wx = self.matmul(x, self.weight)
+            wx = self.bias_add(wx, self.bias)
+            if self.use_act:
+                wx = self.act_func(wx)
+        return wx


 class DeepFMModel(nn.Cell):
@@ -176,6 +192,7 @@ class DeepFMModel(nn.Cell):
                             (list[str], weight_bias_init=['random', 'zero'])
         keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8)
     """
+
     def __init__(self, config):
         super(DeepFMModel, self).__init__()

@@ -188,24 +205,24 @@ class DeepFMModel(nn.Cell):
         self.weight_bias_init = config.weight_bias_init
         self.keep_prob = config.keep_prob
         init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
-                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal'),
-                     ('b', [1], 'normal')]
+                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal')]
         var_map = init_var_dict(self.init_args, init_acts)
         self.fm_w = var_map["W_l2"]
-        self.fm_b = var_map["b"]
         self.embedding_table = var_map["V_l2"]
-        # Deep Layers
-        self.deep_input_dims = self.field_size * self.emb_dim + 1
+        " Deep Layers "
+        self.deep_input_dims = self.field_size * self.emb_dim
         self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
-        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        # FM, linear Layers
+        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True, use_act=False)
+        " FM, linear Layers "
         self.Gatherv2 = P.GatherV2()
         self.Mul = P.Mul()
         self.ReduceSum = P.ReduceSum(keep_dims=False)
@@ -238,16 +255,14 @@ class DeepFMModel(nn.Cell):
         fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
         fm_out = self.Reshape(fm_out, (-1, 1))
         # Deep layer
-        b = self.Reshape(self.fm_b, (1, 1))
-        b = self.Tile(b, (self.batch_size, 1))
         deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
-        deep_in = self.Concat((deep_in, b))
         deep_in = self.dense_layer_1(deep_in)
         deep_in = self.dense_layer_2(deep_in)
         deep_in = self.dense_layer_3(deep_in)
-        deep_out = self.dense_layer_4(deep_in)
+        deep_in = self.dense_layer_4(deep_in)
+        deep_out = self.dense_layer_5(deep_in)
         out = linear_out + fm_out + deep_out
-        return out, fm_id_weight, fm_id_embs
+        return out, self.fm_w, self.embedding_table


 class NetWithLossClass(nn.Cell):
@@ -278,7 +293,7 @@ class TrainStepWrap(nn.Cell):
     """
    TrainStepWrap definition
     """
-    def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0):
+    def __init__(self, network, lr, eps, loss_scale=1000.0):
         super(TrainStepWrap, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_train()
@@ -288,11 +303,24 @@ class TrainStepWrap(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = loss_scale
+        self.reducer_flag = False
+        self.grad_reducer = None
+        parallel_mode = _get_parallel_mode()
+        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+            self.reducer_flag = True
+        if self.reducer_flag:
+            mean = _get_gradients_mean()
+            degree = _get_device_num()
+            self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)

     def construct(self, batch_ids, batch_wts, label):
         weights = self.weights
         loss = self.network(batch_ids, batch_wts, label)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
         grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
+        if self.reducer_flag:
+            # apply grad reducer on grads
+            grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))