Browse Source

Fix batch size check

tags/v1.1.0
huangxinjing 5 years ago
parent
commit
fe89ad2c49
2 changed files with 13 additions and 6 deletions
  1. +5 -4 model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
  2. +8 -2 model_zoo/official/recommend/wide_and_deep_multitable/src/datasets.py

+5 -4 model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py View File

@@ -121,12 +121,12 @@ def train_and_eval(config):
model = Model(train_net, eval_network=eval_net,
metrics={"auc": auc_metric})

eval_callback = EvalCallBack(
model, ds_eval, auc_metric, config)

# Save strategy ckpts according to the rank id, this must be done before initializing the callbacks.
config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")

eval_callback = EvalCallBack(
model, ds_eval, auc_metric, config)

callback = LossCallBack(config=config, per_print_times=20)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
keep_checkpoint_max=5, integrated_save=False)
@@ -146,10 +146,11 @@ if __name__ == "__main__":
wide_deep_config = WideDeepConfig()
wide_deep_config.argparse_init()
context.set_context(mode=context.GRAPH_MODE,
device_target=wide_deep_config.device_target, save_graphs=True)
device_target=wide_deep_config.device_target)
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
init()
context.set_context(save_graphs_path='./graphs_of_device_id_' + str(get_rank()), save_graphs=True)
if wide_deep_config.sparse:
context.set_auto_parallel_context(
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)


+8 -2 model_zoo/official/recommend/wide_and_deep_multitable/src/datasets.py View File

@@ -209,14 +209,20 @@ def _get_tf_dataset(data_dir,
shuffle=shuffle,
schema=schema,
num_parallel_workers=8)
if batch_size <= 0:
raise ValueError("Batch size should be a positive int value, but found {}".format(str(batch_size)))
if batch_size % line_per_sample != 0:
raise ValueError(
"Batch size should be a multiple of {}, but found {}".format(str(line_per_sample), str(batch_size)))

data_set = data_set.batch(int(batch_size / line_per_sample), drop_remainder=True)

operations_list = []
for key in columns_list:
operations_list.append(lambda x: np.array(x).flatten().reshape(input_shape_dict[key]))
print("ssssssssssssssssssssss---------------------" * 10)
print("input_shape_dict start logging")
print(input_shape_dict)
print("---------------------" * 10)
print("input_shape_dict end logging")
print(schema_dict)

def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u):


Loading…
Cancel
Save