| @@ -168,13 +168,19 @@ class EmbeddingLookup(Cell): | |||
| TABLE_COLUMN_SLICE = "table_column_slice" | |||
| def __init__(self, vocab_size, embedding_size, param_init='normal', | |||
| target='CPU', slice_mode='batch_slice', manual_shapes=None, max_norm=None): | |||
| target='CPU', slice_mode='batch_slice', manual_shapes=None, | |||
| max_norm=None, sparse=True): | |||
| super(EmbeddingLookup, self).__init__() | |||
| self.target = target | |||
| if target not in ('CPU', 'DEVICE'): | |||
| raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed ' | |||
| + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') | |||
| self.gatherv2 = P.GatherV2() | |||
| if not sparse and target == 'CPU': | |||
| raise ValueError('When target is CPU, embedding_lookup must be sparse.') | |||
| if sparse: | |||
| self.gatherv2 = P.SparseGatherV2() | |||
| else: | |||
| self.gatherv2 = P.GatherV2() | |||
| self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | |||
| self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name) | |||
| self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name) | |||
| @@ -43,6 +43,8 @@ do | |||
| python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 & | |||
| elif [ $MODE == "field_slice_host_device_mix" ]; then | |||
| python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 --full_batch=1 --field_slice=1 >train_deep$i.log 2>&1 & | |||
| elif [ $MODE == "backward_unique" ]; then | |||
| python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --sparse=1 >train_deep$i.log 2>&1 & | |||
| else | |||
| python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 & | |||
| fi | |||
| @@ -38,7 +38,7 @@ do | |||
| user=$(get_node_user ${cluster_config_path} ${node}) | |||
| passwd=$(get_node_passwd ${cluster_config_path} ${node}) | |||
| echo "------------------${user}@${node}---------------------" | |||
| if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ]; then | |||
| if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ] || [ $MODE == "backward_unique" ]; then | |||
| ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${RANK_TABLE_FILE}" | |||
| else | |||
| echo "[ERROR] mode is wrong" | |||
| @@ -47,6 +47,7 @@ def argparse_init(): | |||
| parser.add_argument("--dataset_type", type=str, default="tfrecord", help="tfrecord/mindrecord/hd5") | |||
| parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server or not") | |||
| parser.add_argument("--field_slice", type=int, default=0, help="Enable split field mode or not") | |||
| parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not") | |||
| return parser | |||
| @@ -84,6 +85,7 @@ class WideDeepConfig(): | |||
| self.parameter_server = 0 | |||
| self.field_slice = False | |||
| self.manual_shape = None | |||
| self.sparse = False | |||
| def argparse_init(self): | |||
| """ | |||
| @@ -118,3 +120,6 @@ class WideDeepConfig(): | |||
| self.dataset_type = args.dataset_type | |||
| self.parameter_server = args.parameter_server | |||
| self.field_slice = bool(args.field_slice) | |||
| self.sparse = bool(args.sparse) | |||
| if self.host_device_mix == 1: | |||
| self.sparse = True | |||
| @@ -144,6 +144,7 @@ class WideDeepModel(nn.Cell): | |||
| if is_auto_parallel: | |||
| self.batch_size = self.batch_size * get_group_size() | |||
| is_field_slice = config.field_slice | |||
| sparse = config.sparse | |||
| self.field_size = config.field_size | |||
| self.vocab_size = config.vocab_size | |||
| self.emb_dim = config.emb_dim | |||
| @@ -197,13 +198,16 @@ class WideDeepModel(nn.Cell): | |||
| self.tile = P.Tile() | |||
| self.concat = P.Concat(axis=1) | |||
| self.cast = P.Cast() | |||
| if is_auto_parallel and host_device_mix and not is_field_slice: | |||
| if is_auto_parallel and sparse and not is_field_slice: | |||
| self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),)) | |||
| self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) | |||
| self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, | |||
| target = 'DEVICE' | |||
| if host_device_mix: | |||
| target = 'CPU' | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | |||
| slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE) | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, | |||
| slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE) | |||
| self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1))) | |||
| self.deep_reshape.add_prim_attr("skip_redistribution", True) | |||
| @@ -231,8 +235,10 @@ class WideDeepModel(nn.Cell): | |||
| self.deep_embeddinglookup.embedding_table.set_param_ps() | |||
| self.wide_embeddinglookup.embedding_table.set_param_ps() | |||
| else: | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE') | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE') | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, | |||
| target='DEVICE', sparse=sparse) | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, | |||
| target='DEVICE', sparse=sparse) | |||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | |||
| def construct(self, id_hldr, wt_hldr): | |||
| @@ -272,9 +278,13 @@ class NetWithLossClass(nn.Cell): | |||
| super(NetWithLossClass, self).__init__(auto_prefix=False) | |||
| host_device_mix = bool(config.host_device_mix) | |||
| parameter_server = bool(config.parameter_server) | |||
| sparse = config.sparse | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| self.no_l2loss = (is_auto_parallel if (host_device_mix or config.field_slice) else parameter_server) | |||
| self.no_l2loss = (is_auto_parallel if (host_device_mix or config.field_slice) | |||
| else parameter_server) | |||
| if sparse: | |||
| self.no_l2loss = True | |||
| self.network = network | |||
| self.l2_coef = config.l2_coef | |||
| self.loss = P.SigmoidCrossEntropyWithLogits() | |||
| @@ -323,7 +333,7 @@ class TrainStepWrap(nn.Cell): | |||
| parameter_server (Bool): Whether run in parameter server mode. Default: False | |||
| """ | |||
| def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False): | |||
| def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, sparse=False): | |||
| super(TrainStepWrap, self).__init__() | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| @@ -340,13 +350,14 @@ class TrainStepWrap(nn.Cell): | |||
| self.weights_w = ParameterTuple(weights_w) | |||
| self.weights_d = ParameterTuple(weights_d) | |||
| if (host_device_mix and is_auto_parallel) or parameter_server: | |||
| if (sparse and is_auto_parallel) or parameter_server: | |||
| self.optimizer_d = LazyAdam( | |||
| self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) | |||
| self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, | |||
| l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens) | |||
| self.optimizer_w.target = "CPU" | |||
| self.optimizer_d.target = "CPU" | |||
| if host_device_mix or parameter_server: | |||
| self.optimizer_w.target = "CPU" | |||
| self.optimizer_d.target = "CPU" | |||
| else: | |||
| self.optimizer_d = Adam( | |||
| self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) | |||
| @@ -31,7 +31,7 @@ def get_WideDeep_net(config): | |||
| WideDeep_net = WideDeepModel(config) | |||
| loss_net = NetWithLossClass(WideDeep_net, config) | |||
| train_net = TrainStepWrap(loss_net) | |||
| train_net = TrainStepWrap(loss_net, sparse=config.sparse) | |||
| eval_net = PredictWithSigmoid(WideDeep_net) | |||
| return train_net, eval_net | |||
| @@ -67,6 +67,7 @@ def test_train_eval(config): | |||
| data_path = config.data_path | |||
| batch_size = config.batch_size | |||
| epochs = config.epochs | |||
| sparse = config.sparse | |||
| if config.dataset_type == "tfrecord": | |||
| dataset_type = DataType.TFRECORD | |||
| elif config.dataset_type == "mindrecord": | |||
| @@ -97,7 +98,8 @@ def test_train_eval(config): | |||
| out = model.eval(ds_eval) | |||
| print("=====" * 5 + "model.eval() initialized: {}".format(out)) | |||
| model.train(epochs, ds_train, | |||
| callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]) | |||
| callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb], | |||
| dataset_sink_mode=(not sparse)) | |||
| if __name__ == "__main__": | |||
| @@ -105,4 +107,5 @@ if __name__ == "__main__": | |||
| wide_deep_config.argparse_init() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target) | |||
| context.set_context(enable_sparse=wide_deep_config.sparse) | |||
| test_train_eval(wide_deep_config) | |||
| @@ -41,7 +41,7 @@ def get_WideDeep_net(config): | |||
| loss_net = NetWithLossClass(WideDeep_net, config) | |||
| loss_net = VirtualDatasetCellTriple(loss_net) | |||
| train_net = TrainStepWrap( | |||
| loss_net, host_device_mix=bool(config.host_device_mix)) | |||
| loss_net, host_device_mix=bool(config.host_device_mix), sparse=config.sparse) | |||
| eval_net = PredictWithSigmoid(WideDeep_net) | |||
| eval_net = VirtualDatasetCellTriple(eval_net) | |||
| return train_net, eval_net | |||
| @@ -84,6 +84,7 @@ def train_and_eval(config): | |||
| else: | |||
| dataset_type = DataType.H5 | |||
| host_device_mix = bool(config.host_device_mix) | |||
| sparse = config.sparse | |||
| print("epochs is {}".format(epochs)) | |||
| if config.full_batch: | |||
| context.set_auto_parallel_context(full_batch=True) | |||
| @@ -134,7 +135,7 @@ def train_and_eval(config): | |||
| if not host_device_mix: | |||
| callback_list.append(ckpoint_cb) | |||
| model.train(epochs, ds_train, callbacks=callback_list, | |||
| dataset_sink_mode=(not host_device_mix)) | |||
| dataset_sink_mode=(not sparse)) | |||
| if __name__ == "__main__": | |||
| @@ -38,7 +38,7 @@ def get_WideDeep_net(config): | |||
| """ | |||
| WideDeep_net = WideDeepModel(config) | |||
| loss_net = NetWithLossClass(WideDeep_net, config) | |||
| train_net = TrainStepWrap(loss_net) | |||
| train_net = TrainStepWrap(loss_net, sparse=config.sparse) | |||
| eval_net = PredictWithSigmoid(WideDeep_net) | |||
| return train_net, eval_net | |||
| @@ -72,6 +72,7 @@ def train_and_eval(config): | |||
| set_seed(1000) | |||
| data_path = config.data_path | |||
| batch_size = config.batch_size | |||
| sparse = config.sparse | |||
| epochs = config.epochs | |||
| if config.dataset_type == "tfrecord": | |||
| dataset_type = DataType.TFRECORD | |||
| @@ -111,7 +112,8 @@ def train_and_eval(config): | |||
| callback_list.append(ckpoint_cb) | |||
| model.train(epochs, ds_train, | |||
| callbacks=callback_list, | |||
| sink_size=ds_train.get_dataset_size()) | |||
| sink_size=ds_train.get_dataset_size(), | |||
| dataset_sink_mode=(not sparse)) | |||
| if __name__ == "__main__": | |||
| @@ -119,6 +121,7 @@ if __name__ == "__main__": | |||
| wide_deep_config.argparse_init() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True) | |||
| context.set_context(enable_sparse=wide_deep_config.sparse) | |||
| init() | |||
| context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank())) | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | |||