From 42b01afc19d5cd6fc1e6af55a167a88034899ed6 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Thu, 10 Dec 2020 15:30:26 +0800 Subject: [PATCH] fix auto parallet full batch --- .../official/recommend/wide_and_deep/src/wide_and_deep.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index 978b1be288..0b0d5f057b 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -438,8 +438,9 @@ class PredictWithSigmoid(nn.Cell): self.network = network self.sigmoid = P.Sigmoid() parallel_mode = context.get_auto_parallel_context("parallel_mode") + full_batch = context.get_auto_parallel_context("full_batch") is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) - if is_auto_parallel: + if is_auto_parallel and full_batch: self.sigmoid.shard(((1, 1),)) def construct(self, batch_ids, batch_wts, labels):