
!7575 optimize the network structure

Merge pull request !7575 from 吴书全/deepfm1021
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 42b25d0923
4 changed files with 83 additions and 55 deletions:

  1. tests/st/model_zoo_tests/DeepFM/src/config.py      (+5, -5)
  2. tests/st/model_zoo_tests/DeepFM/src/dataset.py     (+1, -1)
  3. tests/st/model_zoo_tests/DeepFM/src/deepfm.py      (+76, -48)
  4. tests/st/model_zoo_tests/DeepFM/test_deepfm.py     (+1, -1)

tests/st/model_zoo_tests/DeepFM/src/config.py (+5, -5)

@@ -24,7 +24,7 @@ class DataConfig:
     data_vocab_size = 184965
     train_num_of_parts = 21
     test_num_of_parts = 3
-    batch_size = 1000
+    batch_size = 16000
     data_field_size = 39
     # dataset format, 1: mindrecord, 2: tfrecord, 3: h5
     data_format = 3
@@ -38,7 +38,7 @@ class ModelConfig:
     data_field_size = DataConfig.data_field_size
     data_vocab_size = DataConfig.data_vocab_size
     data_emb_dim = 80
-    deep_layer_args = [[400, 400, 512], "relu"]
+    deep_layer_args = [[1024, 512, 256, 128], "relu"]
     init_args = [-0.01, 0.01]
     weight_bias_init = ['normal', 'normal']
     keep_prob = 0.9
@@ -49,9 +49,9 @@ class TrainConfig:
     Define parameters of training.
     """
     batch_size = DataConfig.batch_size
-    l2_coef = 1e-6
-    learning_rate = 1e-5
-    epsilon = 1e-8
+    l2_coef = 8e-5
+    learning_rate = 5e-4
+    epsilon = 5e-8
     loss_scale = 1024.0
     train_epochs = 3
     save_checkpoint = True
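
The config change enlarges the batch (1000 -> 16000), replaces the [400, 400, 512] MLP with a [1024, 512, 256, 128] tower, and raises the learning rate and l2 coefficient accordingly. A minimal sketch (not part of the diff) of how the new deep_layer_args expands into the per-layer dimension list that DeepFMModel.__init__ builds after this change:

# Sketch only: mirrors the all_dim_list arithmetic in deepfm.py with the new config values.
data_field_size = 39
data_emb_dim = 80
deep_layer_dims_list, deep_layer_act = [[1024, 512, 256, 128], "relu"]

deep_input_dims = data_field_size * data_emb_dim            # 3120; the old "+ 1" bias column is gone
all_dim_list = [deep_input_dims] + deep_layer_dims_list + [1]
print(all_dim_list)                                          # [3120, 1024, 512, 256, 128, 1] -> five DenseLayer instances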


tests/st/model_zoo_tests/DeepFM/src/dataset.py (+1, -1)

@@ -212,7 +212,7 @@ def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=100
                                          np.array(y).flatten().reshape(batch_size, 39),
                                          np.array(z).flatten().reshape(batch_size, 1))),
                 input_columns=['feat_ids', 'feat_vals', 'label'],
-                columns_order=['feat_ids', 'feat_vals', 'label'],
+                column_order=['feat_ids', 'feat_vals', 'label'],
                 num_parallel_workers=8)
     ds = ds.repeat(epochs)
     return ds
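
This one-line change tracks a keyword rename in the MindSpore dataset map API: the argument that fixes output column ordering is spelled column_order rather than the older columns_order. A minimal sketch of the renamed keyword, assuming the MindSpore 1.x dataset API; data_file and batch_size are placeholder names for values that _get_mindrecord_dataset receives elsewhere:

import numpy as np
import mindspore.dataset as de

data_file = "/path/to/train_input_part.mindrecord0"   # placeholder path, not from the diff
batch_size = 16000                                    # DataConfig.batch_size after this change

ds = de.MindDataset(data_file, columns_list=['feat_ids', 'feat_vals', 'label'])
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
                                         np.array(y).flatten().reshape(batch_size, 39),
                                         np.array(z).flatten().reshape(batch_size, 1))),
            input_columns=['feat_ids', 'feat_vals', 'label'],
            column_order=['feat_ids', 'feat_vals', 'label'],   # renamed keyword
            num_parallel_workers=8)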


tests/st/model_zoo_tests/DeepFM/src/deepfm.py (+76, -48)

@@ -24,9 +24,12 @@ from mindspore.ops import operations as P
 from mindspore.nn import Dropout
 from mindspore.nn.optim import Adam
 from mindspore.nn.metrics import Metric
-from mindspore import nn, ParameterTuple, Parameter
-from mindspore.common.initializer import Uniform, initializer, Normal
+from mindspore import nn, Tensor, ParameterTuple, Parameter
+from mindspore.common.initializer import Uniform, initializer
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from mindspore.context import ParallelMode
+from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

 from .callback import EvalCallBack, LossCallBack


@@ -60,7 +63,7 @@ class AUCMetric(Metric):
         return auc


-def init_method(method, shape, name, max_val=0.01):
+def init_method(method, shape, name, max_val=1.0):
     """
     The method of init parameters.

@@ -73,18 +76,18 @@ def init_method(method, shape, name, max_val=0.01):
     Returns:
         Parameter.
     """
-    if method in ['random', 'uniform']:
+    if method in ['uniform']:
         params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
     elif method == "one":
         params = Parameter(initializer("ones", shape, ms_type), name=name)
     elif method == 'zero':
         params = Parameter(initializer("zeros", shape, ms_type), name=name)
     elif method == "normal":
-        params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
+        params = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name)
     return params


-def init_var_dict(init_args, values):
+def init_var_dict(init_args, var_list):
     """
     Init parameter.

@@ -96,17 +99,19 @@ def init_var_dict(init_args, values):
         dict, a dict ot Parameter.
     """
     var_map = {}
-    _, _max_val = init_args
-    for key, shape, init_flag in values:
+    _, max_val = init_args
+    for i, _ in enumerate(var_list):
+        key, shape, method = var_list[i]
         if key not in var_map.keys():
-            if init_flag in ['random', 'uniform']:
-                var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key)
-            elif init_flag == "one":
+            if method in ['random', 'uniform']:
+                var_map[key] = Parameter(initializer(Uniform(max_val), shape, ms_type), name=key)
+            elif method == "one":
                 var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key)
-            elif init_flag == "zero":
+            elif method == "zero":
                 var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key)
-            elif init_flag == 'normal':
-                var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key)
+            elif method == 'normal':
+                var_map[key] = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).
+                                                astype(dtype=np_type)), name=key)
     return var_map
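
The Normal initializer import is dropped and the 'normal' branch now builds parameters from a direct NumPy draw at a fixed scale of 0.01. A minimal standalone sketch of that path, assuming np_type is np.float32 (its usual value in this file) and using the 'V_l2' embedding-table shape from the config:

import numpy as np
from mindspore import Tensor, Parameter

np_type = np.float32                    # assumed module-level alias, as used in deepfm.py
shape = [184965, 80]                    # data_vocab_size x data_emb_dim for 'V_l2'

# 'normal' initialization after this change: a direct NumPy normal draw, scale 0.01.
emb_init = np.random.normal(loc=0.0, scale=0.01, size=shape).astype(np_type)
embedding_table = Parameter(Tensor(emb_init), name="V_l2")
print(emb_init.shape)                   # (184965, 80)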




@@ -122,7 +127,9 @@ class DenseLayer(nn.Cell):
         keep_prob (float): Dropout Layer keep_prob_rate;
         scale_coef (float): input scale coefficient;
     """
-    def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0):
+
+    def __init__(self, input_dim, output_dim, weight_bias_init, act_str, scale_coef=1.0, convert_dtype=True,
+                 use_act=True):
         super(DenseLayer, self).__init__()
         weight_init, bias_init = weight_bias_init
         self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
@@ -131,12 +138,15 @@ class DenseLayer(nn.Cell):
         self.matmul = P.MatMul(transpose_b=False)
         self.bias_add = P.BiasAdd()
         self.cast = P.Cast()
-        self.dropout = Dropout(keep_prob=keep_prob)
+        self.dropout = Dropout(keep_prob=1.0)
         self.mul = P.Mul()
         self.realDiv = P.RealDiv()
         self.scale_coef = scale_coef
+        self.convert_dtype = convert_dtype
+        self.use_act = use_act

     def _init_activation(self, act_str):
+        """Init activation function"""
         act_str = act_str.lower()
         if act_str == "relu":
             act_func = P.ReLU()
@@ -147,17 +157,23 @@ class DenseLayer(nn.Cell):
         return act_func

     def construct(self, x):
-        x = self.act_func(x)
-        if self.training:
-            x = self.dropout(x)
-        x = self.mul(x, self.scale_coef)
-        x = self.cast(x, mstype.float16)
-        weight = self.cast(self.weight, mstype.float16)
-        wx = self.matmul(x, weight)
-        wx = self.cast(wx, mstype.float32)
-        wx = self.realDiv(wx, self.scale_coef)
-        output = self.bias_add(wx, self.bias)
-        return output
+        """Construct function"""
+        x = self.dropout(x)
+        if self.convert_dtype:
+            x = self.cast(x, mstype.float16)
+            weight = self.cast(self.weight, mstype.float16)
+            bias = self.cast(self.bias, mstype.float16)
+            wx = self.matmul(x, weight)
+            wx = self.bias_add(wx, bias)
+            if self.use_act:
+                wx = self.act_func(wx)
+            wx = self.cast(wx, mstype.float32)
+        else:
+            wx = self.matmul(x, self.weight)
+            wx = self.bias_add(wx, self.bias)
+            if self.use_act:
+                wx = self.act_func(wx)
+        return wx


 class DeepFMModel(nn.Cell):
@@ -176,6 +192,7 @@ class DeepFMModel(nn.Cell):
             (list[str], weight_bias_init=['random', 'zero'])
         keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8)
     """
+
     def __init__(self, config):
         super(DeepFMModel, self).__init__()

@@ -188,24 +205,24 @@ class DeepFMModel(nn.Cell):
         self.weight_bias_init = config.weight_bias_init
         self.keep_prob = config.keep_prob
         init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
-                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal'),
-                     ('b', [1], 'normal')]
+                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal')]
         var_map = init_var_dict(self.init_args, init_acts)
         self.fm_w = var_map["W_l2"]
-        self.fm_b = var_map["b"]
         self.embedding_table = var_map["V_l2"]
-        # Deep Layers
-        self.deep_input_dims = self.field_size * self.emb_dim + 1
+        " Deep Layers "
+        self.deep_input_dims = self.field_size * self.emb_dim
         self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
-        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        # FM, linear Layers
+        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True, use_act=False)
+        " FM, linear Layers "
         self.Gatherv2 = P.GatherV2()
         self.Mul = P.Mul()
         self.ReduceSum = P.ReduceSum(keep_dims=False)
@@ -238,16 +255,14 @@ class DeepFMModel(nn.Cell):
         fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
         fm_out = self.Reshape(fm_out, (-1, 1))
         # Deep layer
-        b = self.Reshape(self.fm_b, (1, 1))
-        b = self.Tile(b, (self.batch_size, 1))
         deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
-        deep_in = self.Concat((deep_in, b))
         deep_in = self.dense_layer_1(deep_in)
         deep_in = self.dense_layer_2(deep_in)
         deep_in = self.dense_layer_3(deep_in)
-        deep_out = self.dense_layer_4(deep_in)
+        deep_in = self.dense_layer_4(deep_in)
+        deep_out = self.dense_layer_5(deep_in)
         out = linear_out + fm_out + deep_out
-        return out, fm_id_weight, fm_id_embs
+        return out, self.fm_w, self.embedding_table


 class NetWithLossClass(nn.Cell):
@@ -278,7 +293,7 @@ class TrainStepWrap(nn.Cell):
     """
     TrainStepWrap definition
     """
-    def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0):
+    def __init__(self, network, lr, eps, loss_scale=1000.0):
         super(TrainStepWrap, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_train()
@@ -288,11 +303,24 @@ class TrainStepWrap(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = loss_scale

+        self.reducer_flag = False
+        self.grad_reducer = None
+        parallel_mode = _get_parallel_mode()
+        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+            self.reducer_flag = True
+        if self.reducer_flag:
+            mean = _get_gradients_mean()
+            degree = _get_device_num()
+            self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
+
     def construct(self, batch_ids, batch_wts, label):
         weights = self.weights
         loss = self.network(batch_ids, batch_wts, label)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
         grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
+        if self.reducer_flag:
+            # apply grad reducer on grads
+            grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))
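
Because TrainStepWrap no longer defaults lr and eps, the wrapper now has to be constructed with explicit values from TrainConfig. A hedged sketch of the assumed wiring; the real call sites live in the training scripts outside this diff, and the NetWithLossClass signature with l2_coef is an assumption:

from src.config import ModelConfig, TrainConfig
from src.deepfm import DeepFMModel, NetWithLossClass, TrainStepWrap

model_config = ModelConfig()
train_config = TrainConfig()

deepfm_net = DeepFMModel(model_config)
loss_net = NetWithLossClass(deepfm_net, l2_coef=train_config.l2_coef)   # assumed signature
train_net = TrainStepWrap(loss_net,
                          lr=train_config.learning_rate,    # 5e-4 after this change
                          eps=train_config.epsilon,         # 5e-8 after this change
                          loss_scale=train_config.loss_scale)
train_net.set_train()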






tests/st/model_zoo_tests/DeepFM/test_deepfm.py (+1, -1)

@@ -74,7 +74,7 @@ def test_deepfm():
     export_loss_value = 0.51
     print("loss_callback.loss:", loss_callback.loss)
     assert loss_callback.loss < export_loss_value
-    export_per_step_time = 10.4
+    export_per_step_time = 40.0
     print("time_callback:", time_callback.per_step_time)
     assert time_callback.per_step_time < export_per_step_time
     print("*******test case pass!********")
