Browse Source

Add backward_unique mode (sparse embedding lookup) to wide_and_deep

tags/v1.1.0
yao_yf 5 years ago
parent
commit
187873f975
8 changed files with 50 additions and 19 deletions
  1. mindspore/nn/layer/embedding.py (+8, -2)
  2. model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh (+2, -0)
  3. model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh (+1, -1)
  4. model_zoo/official/recommend/wide_and_deep/src/config.py (+5, -0)
  5. model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py (+21, -10)
  6. model_zoo/official/recommend/wide_and_deep/train_and_eval.py (+5, -2)
  7. model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py (+3, -2)
  8. model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py (+5, -2)

+ 8
- 2
mindspore/nn/layer/embedding.py View File

@@ -168,13 +168,19 @@ class EmbeddingLookup(Cell):
TABLE_COLUMN_SLICE = "table_column_slice"

def __init__(self, vocab_size, embedding_size, param_init='normal',
target='CPU', slice_mode='batch_slice', manual_shapes=None, max_norm=None):
target='CPU', slice_mode='batch_slice', manual_shapes=None,
max_norm=None, sparse=True):
super(EmbeddingLookup, self).__init__()
self.target = target
if target not in ('CPU', 'DEVICE'):
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
self.gatherv2 = P.GatherV2()
if not sparse and target == 'CPU':
raise ValueError('When target is CPU, embedding_lookup must be sparse.')
if sparse:
self.gatherv2 = P.SparseGatherV2()
else:
self.gatherv2 = P.GatherV2()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)


+ 2
- 0
model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh View File

@@ -43,6 +43,8 @@ do
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 &
elif [ $MODE == "field_slice_host_device_mix" ]; then
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 --full_batch=1 --field_slice=1 >train_deep$i.log 2>&1 &
elif [ $MODE == "backward_unique" ]; then
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --sparse=1 >train_deep$i.log 2>&1 &
else
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 &
fi

+ 1
- 1
model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh View File

@@ -38,7 +38,7 @@ do
user=$(get_node_user ${cluster_config_path} ${node})
passwd=$(get_node_passwd ${cluster_config_path} ${node})
echo "------------------${user}@${node}---------------------"
if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ]; then
if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ] || [ $MODE == "backward_unique" ]; then
ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${RANK_TABLE_FILE}"
else
echo "[ERROR] mode is wrong"


+ 5
- 0
model_zoo/official/recommend/wide_and_deep/src/config.py View File

@@ -47,6 +47,7 @@ def argparse_init():
parser.add_argument("--dataset_type", type=str, default="tfrecord", help="tfrecord/mindrecord/hd5")
parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server of not")
parser.add_argument("--field_slice", type=int, default=0, help="Enable split field mode or not")
parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not")
return parser


@@ -84,6 +85,7 @@ class WideDeepConfig():
self.parameter_server = 0
self.field_slice = False
self.manual_shape = None
self.sparse = False

def argparse_init(self):
"""
@@ -118,3 +120,6 @@ class WideDeepConfig():
self.dataset_type = args.dataset_type
self.parameter_server = args.parameter_server
self.field_slice = bool(args.field_slice)
self.sparse = bool(args.sparse)
if self.host_device_mix == 1:
self.sparse = True

+ 21
- 10
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py View File

@@ -144,6 +144,7 @@ class WideDeepModel(nn.Cell):
if is_auto_parallel:
self.batch_size = self.batch_size * get_group_size()
is_field_slice = config.field_slice
sparse = config.sparse
self.field_size = config.field_size
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
@@ -197,13 +198,16 @@ class WideDeepModel(nn.Cell):
self.tile = P.Tile()
self.concat = P.Concat(axis=1)
self.cast = P.Cast()
if is_auto_parallel and host_device_mix and not is_field_slice:
if is_auto_parallel and sparse and not is_field_slice:
self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),))
self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),))
self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1)))
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
target = 'DEVICE'
if host_device_mix:
target = 'CPU'
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE)
self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1)))
self.deep_reshape.add_prim_attr("skip_redistribution", True)
@@ -231,8 +235,10 @@ class WideDeepModel(nn.Cell):
self.deep_embeddinglookup.embedding_table.set_param_ps()
self.wide_embeddinglookup.embedding_table.set_param_ps()
else:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
target='DEVICE', sparse=sparse)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
target='DEVICE', sparse=sparse)
self.embedding_table = self.deep_embeddinglookup.embedding_table
def construct(self, id_hldr, wt_hldr):
@@ -272,9 +278,13 @@ class NetWithLossClass(nn.Cell):
super(NetWithLossClass, self).__init__(auto_prefix=False)
host_device_mix = bool(config.host_device_mix)
parameter_server = bool(config.parameter_server)
sparse = config.sparse
parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.no_l2loss = (is_auto_parallel if (host_device_mix or config.field_slice) else parameter_server)
self.no_l2loss = (is_auto_parallel if (host_device_mix or config.field_slice)
else parameter_server)
if sparse:
self.no_l2loss = True
self.network = network
self.l2_coef = config.l2_coef
self.loss = P.SigmoidCrossEntropyWithLogits()
@@ -323,7 +333,7 @@ class TrainStepWrap(nn.Cell):
parameter_server (Bool): Whether run in parameter server mode. Default: False
"""
def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False):
def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, sparse=False):
super(TrainStepWrap, self).__init__()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
@@ -340,13 +350,14 @@ class TrainStepWrap(nn.Cell):
self.weights_w = ParameterTuple(weights_w)
self.weights_d = ParameterTuple(weights_d)
if (host_device_mix and is_auto_parallel) or parameter_server:
if (sparse and is_auto_parallel) or parameter_server:
self.optimizer_d = LazyAdam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
self.optimizer_w.target = "CPU"
self.optimizer_d.target = "CPU"
if host_device_mix or parameter_server:
self.optimizer_w.target = "CPU"
self.optimizer_d.target = "CPU"
else:
self.optimizer_d = Adam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)


+ 5
- 2
model_zoo/official/recommend/wide_and_deep/train_and_eval.py View File

@@ -31,7 +31,7 @@ def get_WideDeep_net(config):
WideDeep_net = WideDeepModel(config)

loss_net = NetWithLossClass(WideDeep_net, config)
train_net = TrainStepWrap(loss_net)
train_net = TrainStepWrap(loss_net, sparse=config.sparse)
eval_net = PredictWithSigmoid(WideDeep_net)

return train_net, eval_net
@@ -67,6 +67,7 @@ def test_train_eval(config):
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
sparse = config.sparse
if config.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif config.dataset_type == "mindrecord":
@@ -97,7 +98,8 @@ def test_train_eval(config):
out = model.eval(ds_eval)
print("=====" * 5 + "model.eval() initialized: {}".format(out))
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb],
dataset_sink_mode=(not sparse))


if __name__ == "__main__":
@@ -105,4 +107,5 @@ if __name__ == "__main__":
wide_deep_config.argparse_init()

context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
context.set_context(enable_sparse=wide_deep_config.sparse)
test_train_eval(wide_deep_config)

+ 3
- 2
model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py View File

@@ -41,7 +41,7 @@ def get_WideDeep_net(config):
loss_net = NetWithLossClass(WideDeep_net, config)
loss_net = VirtualDatasetCellTriple(loss_net)
train_net = TrainStepWrap(
loss_net, host_device_mix=bool(config.host_device_mix))
loss_net, host_device_mix=bool(config.host_device_mix), sparse=config.sparse)
eval_net = PredictWithSigmoid(WideDeep_net)
eval_net = VirtualDatasetCellTriple(eval_net)
return train_net, eval_net
@@ -84,6 +84,7 @@ def train_and_eval(config):
else:
dataset_type = DataType.H5
host_device_mix = bool(config.host_device_mix)
sparse = config.sparse
print("epochs is {}".format(epochs))
if config.full_batch:
context.set_auto_parallel_context(full_batch=True)
@@ -134,7 +135,7 @@ def train_and_eval(config):
if not host_device_mix:
callback_list.append(ckpoint_cb)
model.train(epochs, ds_train, callbacks=callback_list,
dataset_sink_mode=(not host_device_mix))
dataset_sink_mode=(not sparse))


if __name__ == "__main__":


+ 5
- 2
model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py View File

@@ -38,7 +38,7 @@ def get_WideDeep_net(config):
"""
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
train_net = TrainStepWrap(loss_net)
train_net = TrainStepWrap(loss_net, sparse=config.sparse)
eval_net = PredictWithSigmoid(WideDeep_net)
return train_net, eval_net

@@ -72,6 +72,7 @@ def train_and_eval(config):
set_seed(1000)
data_path = config.data_path
batch_size = config.batch_size
sparse = config.sparse
epochs = config.epochs
if config.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
@@ -111,7 +112,8 @@ def train_and_eval(config):
callback_list.append(ckpoint_cb)
model.train(epochs, ds_train,
callbacks=callback_list,
sink_size=ds_train.get_dataset_size())
sink_size=ds_train.get_dataset_size(),
dataset_sink_mode=(not sparse))


if __name__ == "__main__":
@@ -119,6 +121,7 @@ if __name__ == "__main__":
wide_deep_config.argparse_init()

context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
context.set_context(enable_sparse=wide_deep_config.sparse)
init()
context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,


Loading…
Cancel
Save