|
|
@@ -25,13 +25,11 @@ class Net(Cell): |
|
|
def __init__(self, mul_weight, strategy1=None, strategy2=None): |
|
|
def __init__(self, mul_weight, strategy1=None, strategy2=None): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self.mul = P.Mul().shard(strategy1) |
|
|
self.mul = P.Mul().shard(strategy1) |
|
|
self.mul2 = P.Mul().shard(strategy1) |
|
|
|
|
|
self.dropout_do_mask = P.DropoutDoMask().shard(strategy2) |
|
|
self.dropout_do_mask = P.DropoutDoMask().shard(strategy2) |
|
|
self.dropout_gen_mask = P.DropoutGenMask() |
|
|
self.dropout_gen_mask = P.DropoutGenMask() |
|
|
self.get_shape = P.Shape() |
|
|
self.get_shape = P.Shape() |
|
|
self.cast = P.Cast() |
|
|
self.cast = P.Cast() |
|
|
self.mul_weight = Parameter(mul_weight, "w1") |
|
|
self.mul_weight = Parameter(mul_weight, "w1") |
|
|
self.mul_weight2 = Parameter(mul_weight, "w2") |
|
|
|
|
|
self.keep_prob = Tensor(0.9) |
|
|
self.keep_prob = Tensor(0.9) |
|
|
|
|
|
|
|
|
def construct(self, x, b): |
|
|
def construct(self, x, b): |
|
|
@@ -41,7 +39,6 @@ class Net(Cell): |
|
|
keep_prob = self.cast(self.keep_prob, dtype) |
|
|
keep_prob = self.cast(self.keep_prob, dtype) |
|
|
mask = self.dropout_gen_mask(shape, keep_prob) |
|
|
mask = self.dropout_gen_mask(shape, keep_prob) |
|
|
out = self.dropout_do_mask(out, mask, keep_prob) |
|
|
out = self.dropout_do_mask(out, mask, keep_prob) |
|
|
out = self.mul2(out, self.mul_weight2) |
|
|
|
|
|
return out |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|