|
|
|
@@ -274,7 +274,9 @@ def _reset_op_id(): |
|
|
|
def _parallel_predict_check(): |
|
|
|
"""validate parallel model prediction""" |
|
|
|
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): |
|
|
|
if not context.get_auto_parallel_context("full_batch"): |
|
|
|
dataset_strategy = context.get_auto_parallel_context("dataset_strategy") |
|
|
|
is_shard_dataset_mp = (dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch")) |
|
|
|
if not context.get_auto_parallel_context("full_batch") and not is_shard_dataset_mp: |
|
|
|
raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.') |
|
|
|
|
|
|
|
|
|
|
|
|