Browse Source

!9777 fix auto parallet full batch

From: @limingqi107
Reviewed-by: @chujinjin,@cristoval
Signed-off-by: @cristoval
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f7c339fce1
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py

+ 2
- 1
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py View File

@@ -438,8 +438,9 @@ class PredictWithSigmoid(nn.Cell):
self.network = network self.network = network
self.sigmoid = P.Sigmoid() self.sigmoid = P.Sigmoid()
parallel_mode = context.get_auto_parallel_context("parallel_mode") parallel_mode = context.get_auto_parallel_context("parallel_mode")
full_batch = context.get_auto_parallel_context("full_batch")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if is_auto_parallel:
if is_auto_parallel and full_batch:
self.sigmoid.shard(((1, 1),)) self.sigmoid.shard(((1, 1),))
def construct(self, batch_ids, batch_wts, labels): def construct(self, batch_ids, batch_wts, labels):


Loading…
Cancel
Save