|
|
|
@@ -78,7 +78,7 @@ class Dropout(Cell): |
|
|
|
if keep_prob <= 0 or keep_prob > 1: |
|
|
|
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) |
|
|
|
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) |
|
|
|
self.keep_prob = Tensor(keep_prob) |
|
|
|
self.keep_prob = keep_prob |
|
|
|
self.seed0 = seed0 |
|
|
|
self.seed1 = seed1 |
|
|
|
self.dtype = dtype |
|
|
|
|