Browse Source

!15643 insert virtual div only for first input of dropout do mask

From: @yangzhenzhang
Reviewed-by: @stsuteng,@kisnwang
Signed-off-by: @stsuteng
pull/15643/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
3cfd58e8e0
2 changed files with 5 additions and 3 deletions
  1. +5
    -0
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  2. +0
    -3
      tests/ut/python/parallel/test_dropout_do_mask.py

+ 5
- 0
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -987,6 +987,11 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
FuncGraphManagerPtr manager = func_graph->manager(); FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);


if (IsSomePrimitive(node, DROPOUT_DO_MASK)) {
MS_LOG(INFO) << "Handle dropout do mask, only insert the virtual div to input[0]";
node_size = 2;
}

for (size_t index = 1; index < node_size; ++index) { for (size_t index = 1; index < node_size; ++index) {
AnfNodePtr input = node->input(index); AnfNodePtr input = node->input(index);
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);


+ 0
- 3
tests/ut/python/parallel/test_dropout_do_mask.py View File

@@ -25,13 +25,11 @@ class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None): def __init__(self, mul_weight, strategy1=None, strategy2=None):
super().__init__() super().__init__()
self.mul = P.Mul().shard(strategy1) self.mul = P.Mul().shard(strategy1)
self.mul2 = P.Mul().shard(strategy1)
self.dropout_do_mask = P.DropoutDoMask().shard(strategy2) self.dropout_do_mask = P.DropoutDoMask().shard(strategy2)
self.dropout_gen_mask = P.DropoutGenMask() self.dropout_gen_mask = P.DropoutGenMask()
self.get_shape = P.Shape() self.get_shape = P.Shape()
self.cast = P.Cast() self.cast = P.Cast()
self.mul_weight = Parameter(mul_weight, "w1") self.mul_weight = Parameter(mul_weight, "w1")
self.mul_weight2 = Parameter(mul_weight, "w2")
self.keep_prob = Tensor(0.9) self.keep_prob = Tensor(0.9)


def construct(self, x, b): def construct(self, x, b):
@@ -41,7 +39,6 @@ class Net(Cell):
keep_prob = self.cast(self.keep_prob, dtype) keep_prob = self.cast(self.keep_prob, dtype)
mask = self.dropout_gen_mask(shape, keep_prob) mask = self.dropout_gen_mask(shape, keep_prob)
out = self.dropout_do_mask(out, mask, keep_prob) out = self.dropout_do_mask(out, mask, keep_prob)
out = self.mul2(out, self.mul_weight2)
return out return out






Loading…
Cancel
Save