diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index 978b1be288..0b0d5f057b 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -438,8 +438,9 @@ class PredictWithSigmoid(nn.Cell): self.network = network self.sigmoid = P.Sigmoid() parallel_mode = context.get_auto_parallel_context("parallel_mode") + full_batch = context.get_auto_parallel_context("full_batch") is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) - if is_auto_parallel: + if is_auto_parallel and full_batch: self.sigmoid.shard(((1, 1),)) def construct(self, batch_ids, batch_wts, labels):