Merge pull request !5309 from yao_yf/wide_and_deep_field_slice
@@ -41,6 +41,8 @@ do
     cd ${execute_path}/device_$RANK_ID || exit
     if [ $MODE == "host_device_mix" ]; then
         python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 &
+    elif [ $MODE == "field_slice_host_device_mix" ]; then
+        python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 --full_batch=1 --field_slice=1 >train_deep$i.log 2>&1 &
     else
         python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 &
     fi
@@ -38,7 +38,7 @@ do
     user=$(get_node_user ${cluster_config_path} ${node})
     passwd=$(get_node_passwd ${cluster_config_path} ${node})
     echo "------------------${user}@${node}---------------------"
-    if [ $MODE == "host_device_mix" ]; then
+    if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ]; then
         ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${RANK_TABLE_FILE}"
     else
         echo "[ERROR] mode is wrong"
@@ -25,7 +25,7 @@ def argparse_init():
     parser.add_argument("--data_path", type=str, default="./test_raw_data/",
                         help="This should be set to the same directory given to the data_download's data_dir argument")
     parser.add_argument("--epochs", type=int, default=15, help="Total train epochs")
-    parser.add_argument("--full_batch", type=bool, default=False, help="Enable loading the full batch ")
+    parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ")
     parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.")
     parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
     parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
@@ -46,6 +46,7 @@ def argparse_init():
     parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not")
     parser.add_argument("--dataset_type", type=str, default="tfrecord", help="tfrecord/mindrecord/hd5")
     parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server of not")
+    parser.add_argument("--field_slice", type=int, default=0, help="Enable split field mode or not")
     return parser
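A note on the `--full_batch` change above: argparse applies `type` to the raw command-line string, and `bool("0")` is `True`, so with `type=bool` the flag could never actually be switched off from the command line; parsing an int and casting later (as config.py now does) avoids that. A minimal standalone sketch of the pitfall, with hypothetical flag names:

```python
# Standalone sketch of the argparse pitfall behind the bool -> int change.
# The flag names are hypothetical; only the behavior matters.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--broken", type=bool, default=False)  # bool("0") -> True
parser.add_argument("--fixed", type=int, default=0)        # int("0") -> 0

args = parser.parse_args(["--broken=0", "--fixed=0"])
print(args.broken)       # True: any non-empty string is truthy
print(bool(args.fixed))  # False: parse as int, cast to bool explicitly
```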
@@ -81,6 +82,8 @@ class WideDeepConfig():
         self.host_device_mix = 0
         self.dataset_type = "tfrecord"
         self.parameter_server = 0
+        self.field_slice = False
+        self.manual_shape = None
 
     def argparse_init(self):
         """
@@ -91,7 +94,7 @@ class WideDeepConfig():
         self.device_target = args.device_target
         self.data_path = args.data_path
         self.epochs = args.epochs
-        self.full_batch = args.full_batch
+        self.full_batch = bool(args.full_batch)
         self.batch_size = args.batch_size
         self.eval_batch_size = args.eval_batch_size
         self.field_size = args.field_size
@@ -114,3 +117,4 @@ class WideDeepConfig():
         self.host_device_mix = args.host_device_mix
         self.dataset_type = args.dataset_type
         self.parameter_server = args.parameter_server
+        self.field_slice = bool(args.field_slice)
@@ -23,6 +23,7 @@ import pandas as pd
+import math
 import mindspore.dataset.engine as de
 import mindspore.common.dtype as mstype
 
 class DataType(Enum):
     """
     Enumerate supported dataset format.
@@ -83,9 +84,9 @@ class H5Dataset():
                 yield os.path.join(self._hdf_data_dir,
                                    self._file_prefix + '_input_part_' + str(
                                        p) + '.h5'), \
-                    os.path.join(self._hdf_data_dir,
-                                 self._file_prefix + '_output_part_' + str(
-                                     p) + '.h5'), i + 1 == len(parts)
+                      os.path.join(self._hdf_data_dir,
+                                   self._file_prefix + '_output_part_' + str(
+                                       p) + '.h5'), i + 1 == len(parts)
 
     def _generator(self, X, y, batch_size, shuffle=True):
         """
@@ -169,8 +170,41 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
     return ds
 
+
+def _padding_func(batch_size, manual_shape, target_column, field_size=39):
+    """
+    get padding_func
+    """
+    if manual_shape:
+        generate_concat_offset = [item[0] + item[1] for item in manual_shape]
+        part_size = int(target_column / len(generate_concat_offset))
+        filled_value = []
+        for i in range(field_size, target_column):
+            filled_value.append(generate_concat_offset[i // part_size] - 1)
+        print("Filled Value:", filled_value)
+
+        def padding_func(x, y, z):
+            x = np.array(x).flatten().reshape(batch_size, field_size)
+            y = np.array(y).flatten().reshape(batch_size, field_size)
+            z = np.array(z).flatten().reshape(batch_size, 1)
+            x_id = np.ones((batch_size, target_column - field_size),
+                           dtype=np.int32) * filled_value
+            x_id = np.concatenate([x, x_id.astype(dtype=np.int32)], axis=1)
+            mask = np.concatenate(
+                [y, np.zeros((batch_size, target_column - field_size), dtype=np.float32)], axis=1)
+            return (x_id, mask, z)
+    else:
+        def padding_func(x, y, z):
+            x = np.array(x).flatten().reshape(batch_size, field_size)
+            y = np.array(y).flatten().reshape(batch_size, field_size)
+            z = np.array(z).flatten().reshape(batch_size, 1)
+            return (x, y, z)
+    return padding_func
+
+
 def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
-                    line_per_sample=1000, rank_size=None, rank_id=None):
+                    line_per_sample=1000, rank_size=None, rank_id=None,
+                    manual_shape=None, target_column=40):
     """
     get_tf_dataset
     """
@@ -189,21 +223,22 @@ def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
         ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8,
                                 num_shards=rank_size, shard_id=rank_id, shard_equal_rows=True)
     else:
-        ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8)
+        ds = de.TFRecordDataset(dataset_files=dataset_files,
+                                shuffle=shuffle, schema=schema, num_parallel_workers=8)
     ds = ds.batch(int(batch_size / line_per_sample),
                   drop_remainder=True)
-    ds = ds.map(operations=(lambda x, y, z: (
-        np.array(x).flatten().reshape(batch_size, 39),
-        np.array(y).flatten().reshape(batch_size, 39),
-        np.array(z).flatten().reshape(batch_size, 1))),
+    ds = ds.map(operations=_padding_func(batch_size, manual_shape, target_column),
                 input_columns=['feat_ids', 'feat_vals', 'label'],
                 columns_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8)
-    #if train_mode:
+    # if train_mode:
     ds = ds.repeat(epochs)
     return ds
 
 def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
-                            line_per_sample=1000, rank_size=None, rank_id=None):
+                            line_per_sample=1000, rank_size=None, rank_id=None,
+                            manual_shape=None, target_column=40):
     """
     Get dataset with mindrecord format.
@@ -233,9 +268,7 @@ def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=100
                           columns_list=['feat_ids', 'feat_vals', 'label'],
                           shuffle=shuffle, num_parallel_workers=8)
     ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
-    ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
-                                             np.array(y).flatten().reshape(batch_size, 39),
-                                             np.array(z).flatten().reshape(batch_size, 1))),
+    ds = ds.map(_padding_func(batch_size, manual_shape, target_column),
                 input_columns=['feat_ids', 'feat_vals', 'label'],
                 columns_order=['feat_ids', 'feat_vals', 'label'],
                 num_parallel_workers=8)
@@ -243,18 +276,84 @@ def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=100
     return ds
 
+
+def _get_vocab_size(target_column_number, worker_size, total_vocab_size, multiply=False, per_vocab_size=None):
+    """
+    get_vocab_size
+    """
+    # Per-field vocabulary sizes for the 39 input fields
+    individual_vocabs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 691, 540, 20855, 23639, 182, 15,
+                         10091, 347, 4, 16366, 4494, 21293, 3103, 27, 6944, 22366, 11, 3267, 1610,
+                         5, 21762, 14, 15, 15030, 61, 12220]
+    new_vocabs = individual_vocabs + [1] * \
+        (target_column_number - len(individual_vocabs))
+    part_size = int(target_column_number / worker_size)
+
+    # According to the workers, we merge some fields into the same part
+    new_vocab_size = []
+    for i in range(0, target_column_number, part_size):
+        new_vocab_size.append(sum(new_vocabs[i: i + part_size]))
+
+    index_offsets = [0]
+    # The original per-part feature counts are used to calculate the offsets
+    features = [item for item in new_vocab_size]
+
+    # If per_vocab_size is given, pad every part to that size
+    if per_vocab_size is not None:
+        new_vocab_size = [per_vocab_size] * worker_size
+    else:
+        # Expand the vocabulary of each part by the multiplier
+        if multiply is True:
+            cur_sum = sum(new_vocab_size)
+            k = total_vocab_size / cur_sum
+            new_vocab_size = [
+                math.ceil(int(item * k) / worker_size) * worker_size for item in new_vocab_size]
+            new_vocab_size = [(item // 8 + 1) * 8 for item in new_vocab_size]
+        else:
+            if total_vocab_size > sum(new_vocab_size):
+                new_vocab_size[-1] = total_vocab_size - \
+                    sum(new_vocab_size[:-1])
+                new_vocab_size = [item for item in new_vocab_size]
+            else:
+                raise ValueError(
+                    "Please provide the correct vocab size, now is {}".format(total_vocab_size))
+
+    for i in range(worker_size - 1):
+        off = index_offsets[i] + features[i]
+        index_offsets.append(off)
+    print("the offset: ", index_offsets)
+
+    manual_shape = tuple(
+        ((new_vocab_size[i], index_offsets[i]) for i in range(worker_size)))
+    vocab_total = sum(new_vocab_size)
+    return manual_shape, vocab_total
+
+
+def compute_manual_shape(config, worker_size):
+    target_column = (config.field_size // worker_size + 1) * worker_size
+    config.field_size = target_column
+    manual_shape, vocab_total = _get_vocab_size(target_column, worker_size, total_vocab_size=config.vocab_size,
+                                                per_vocab_size=None, multiply=False)
+    config.manual_shape = manual_shape
+    config.vocab_size = int(vocab_total)
+
+
 def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
-                   data_type=DataType.TFRECORD, line_per_sample=1000, rank_size=None, rank_id=None):
+                   data_type=DataType.TFRECORD, line_per_sample=1000,
+                   rank_size=None, rank_id=None, manual_shape=None, target_column=40):
     """
     create_dataset
     """
     if data_type == DataType.TFRECORD:
         return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
-                               line_per_sample, rank_size=rank_size, rank_id=rank_id)
+                               line_per_sample, rank_size=rank_size, rank_id=rank_id,
+                               manual_shape=manual_shape, target_column=target_column)
     if data_type == DataType.MINDRECORD:
-        return _get_mindrecord_dataset(data_dir, train_mode, epochs,
-                                       batch_size, line_per_sample,
-                                       rank_size, rank_id)
+        return _get_mindrecord_dataset(data_dir, train_mode, epochs, batch_size,
+                                       line_per_sample, rank_size=rank_size, rank_id=rank_id,
+                                       manual_shape=manual_shape, target_column=target_column)
 
     if rank_size > 1:
         raise RuntimeError("please use tfrecord dataset.")
@@ -143,6 +143,7 @@ class WideDeepModel(nn.Cell):
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         if is_auto_parallel:
             self.batch_size = self.batch_size * get_group_size()
+        is_field_slice = config.field_slice
         self.field_size = config.field_size
         self.vocab_size = config.vocab_size
         self.emb_dim = config.emb_dim
@@ -196,11 +197,10 @@
         self.tile = P.Tile()
         self.concat = P.Concat(axis=1)
         self.cast = P.Cast()
-        if is_auto_parallel and host_device_mix:
+        if is_auto_parallel and host_device_mix and not is_field_slice:
             self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
             self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
             self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
-            self.dense_layer_1.matmul.add_prim_attr("field_size", config.field_size)
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                            slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
             self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
@@ -209,9 +209,20 @@
             self.deep_reshape.add_prim_attr("skip_redistribution", True)
             self.reduce_sum.add_prim_attr("cross_batch", True)
             self.embedding_table = self.deep_embeddinglookup.embedding_table
-        elif host_device_mix:
-            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
-            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
+        elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape:
+            manual_shapes = tuple((s[0] for s in config.manual_shape))
+            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
+                                                           slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE,
+                                                           manual_shapes=manual_shapes)
+            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
+                                                           slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE,
+                                                           manual_shapes=manual_shapes)
+            self.deep_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
+            self.wide_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
+            self.reduce_sum.set_strategy(((1, get_group_size(), 1),))
+            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
+            self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
+            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
+            self.embedding_table = self.deep_embeddinglookup.embedding_table
         elif parameter_server:
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
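The new branch passes `manual_shapes` (the per-worker slice row counts) to `nn.EmbeddingLookup` in `FIELD_SLICE` mode and shards the `(batch, field, emb)` element-wise ops on the field axis via `set_strategy(((1, get_group_size(), 1), ...))`. As I read the scheme, each device owns one contiguous vocabulary slice and serves the fields whose (padded) ids land in it. A hedged numpy simulation of that idea, not the MindSpore API:

```python
# Hedged numpy simulation of a field-sliced lookup across 2 mock workers.
import numpy as np

sizes, offsets = (10, 30), (0, 10)         # rows per slice, global start rows
table = np.arange(40 * 4).reshape(40, 4)   # full (vocab=40, emb=4) table

ids = np.array([[3, 12, 25], [7, 10, 39]])                    # global ids
owner = np.searchsorted(np.cumsum(sizes), ids, side="right")  # owning worker

out = np.empty(ids.shape + (4,))
for w in range(2):                          # each "worker" serves its slice
    slice_rows = table[offsets[w]:offsets[w] + sizes[w]]
    mask = owner == w
    out[mask] = slice_rows[(ids - offsets[w])[mask]]
print(np.array_equal(out, table[ids]))      # True: matches one global table
```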
@@ -263,7 +274,7 @@ class NetWithLossClass(nn.Cell):
         parameter_server = bool(config.parameter_server)
         parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-        self.no_l2loss = (is_auto_parallel if host_device_mix else parameter_server)
+        self.no_l2loss = (is_auto_parallel if (host_device_mix or config.field_slice) else parameter_server)
         self.network = network
         self.l2_coef = config.l2_coef
         self.loss = P.SigmoidCrossEntropyWithLogits()
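The updated `no_l2loss` ternary is dense; spelled out, the L2 penalty is skipped under (semi-)auto-parallel whenever host/device mix or field slicing is active, and otherwise only under parameter-server mode. An equivalent explicit form, for illustration only:

```python
# Equivalent spelled-out form of the one-line ternary above (not a change).
if host_device_mix or config.field_slice:
    no_l2loss = is_auto_parallel   # skip L2 loss in (semi-)auto-parallel runs
else:
    no_l2loss = parameter_server   # otherwise skip it only in PS mode
```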
@@ -27,12 +27,13 @@ from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
-from src.datasets import create_dataset, DataType
+from src.datasets import create_dataset, DataType, compute_manual_shape
 from src.metrics import AUCMetric
 from src.config import WideDeepConfig
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
 
 def get_WideDeep_net(config):
     """
     Get network of wide&deep model.
@@ -40,7 +41,8 @@ def get_WideDeep_net(config):
     WideDeep_net = WideDeepModel(config)
     loss_net = NetWithLossClass(WideDeep_net, config)
     loss_net = VirtualDatasetCellTriple(loss_net)
-    train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix))
+    train_net = TrainStepWrap(
+        loss_net, host_device_mix=bool(config.host_device_mix))
     eval_net = PredictWithSigmoid(WideDeep_net)
     eval_net = VirtualDatasetCellTriple(eval_net)
     return train_net, eval_net
@@ -50,6 +52,7 @@ class ModelBuilder():
     """
     ModelBuilder
     """
+
     def __init__(self):
         pass
@@ -86,10 +89,19 @@ def train_and_eval(config):
     if config.full_batch:
         context.set_auto_parallel_context(full_batch=True)
         de.config.set_seed(1)
-        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
-                                  batch_size=batch_size*get_group_size(), data_type=dataset_type)
-        ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
-                                 batch_size=batch_size*get_group_size(), data_type=dataset_type)
+        if config.field_slice:
+            compute_manual_shape(config, get_group_size())
+            ds_train = create_dataset(data_path, train_mode=True, epochs=1,
+                                      batch_size=batch_size*get_group_size(), data_type=dataset_type,
+                                      manual_shape=config.manual_shape, target_column=config.field_size)
+            ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
+                                     batch_size=batch_size*get_group_size(), data_type=dataset_type,
+                                     manual_shape=config.manual_shape, target_column=config.field_size)
+        else:
+            ds_train = create_dataset(data_path, train_mode=True, epochs=1,
+                                      batch_size=batch_size*get_group_size(), data_type=dataset_type)
+            ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
+                                     batch_size=batch_size*get_group_size(), data_type=dataset_type)
     else:
         ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                   batch_size=batch_size, rank_id=get_rank(),
@@ -106,9 +118,11 @@ def train_and_eval(config):
     train_net.set_train()
     auc_metric = AUCMetric()
 
-    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
+    model = Model(train_net, eval_network=eval_net,
+                  metrics={"auc": auc_metric})
 
-    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
+    eval_callback = EvalCallBack(
+        model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
 
     callback = LossCallBack(config=config, per_print_times=20)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
@@ -116,16 +130,19 @@
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                  directory=config.ckpt_path, config=ckptconfig)
     context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
-    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
+    callback_list = [TimeMonitor(
+        ds_train.get_dataset_size()), eval_callback, callback]
     if not host_device_mix:
         callback_list.append(ckpoint_cb)
-    model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=(not host_device_mix))
+    model.train(epochs, ds_train, callbacks=callback_list,
+                dataset_sink_mode=(not host_device_mix))
 
 
 if __name__ == "__main__":
     wide_deep_config = WideDeepConfig()
     wide_deep_config.argparse_init()
-    context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target=wide_deep_config.device_target, save_graphs=True)
     context.set_context(variable_memory_max_size="24GB")
     context.set_context(enable_sparse=True)
     set_multi_subgraphs()
@@ -134,7 +151,9 @@ if __name__ == "__main__":
     elif wide_deep_config.device_target == "GPU":
         init("nccl")
     if wide_deep_config.host_device_mix == 1:
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
+        context.set_auto_parallel_context(
+            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
     else:
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
+        context.set_auto_parallel_context(
+            parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
     train_and_eval(wide_deep_config)
@@ -101,12 +101,8 @@ def train_and_eval(config):
     callback = LossCallBack(config=config)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
-    if config.device_target == "Ascend":
-        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
-                                     directory=config.ckpt_path, config=ckptconfig)
-    elif config.device_target == "GPU":
-        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()),
-                                     directory=config.ckpt_path, config=ckptconfig)
+    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
+                                 directory=config.ckpt_path, config=ckptconfig)
     out = model.eval(ds_eval)
     print("=====" * 5 + "model.eval() initialized: {}".format(out))
     callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
@@ -103,14 +103,13 @@ def train_and_eval(config):
     callback = LossCallBack(config=config)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
-    if config.device_target == "Ascend":
-        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
-                                     directory=config.ckpt_path, config=ckptconfig)
-    elif config.device_target == "GPU":
-        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()),
-                                     directory=config.ckpt_path, config=ckptconfig)
+    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
+                                 directory=config.ckpt_path, config=ckptconfig)
+    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
+    if get_rank() == 0:
+        callback_list.append(ckpoint_cb)
     model.train(epochs, ds_train,
-                callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb],
+                callbacks=callback_list,
                 dataset_sink_mode=(not parameter_server))