diff --git a/mindspore/core/ops/dropout.cc b/mindspore/core/ops/dropout.cc index 08f7c4a9e1..adeb93f8af 100644 --- a/mindspore/core/ops/dropout.cc +++ b/mindspore/core/ops/dropout.cc @@ -24,16 +24,31 @@ namespace mindspore { namespace ops { -void Dropout::Init(const float ratio) { this->set_ratio(ratio); } +void Dropout::Init(const float ratio, const float keep_prob) { + this->set_ratio(ratio); + this->set_keep_prob(keep_prob); +} + void Dropout::set_ratio(const float ratio) { CheckAndConvertUtils::CheckInRange(kRatio, ratio, kIncludeRight, {0.0, 1.0}, this->name()); - this->AddAttr(kKeepProb, MakeValue(ratio)); + this->AddAttr(kRatio, MakeValue(ratio)); +} + +void Dropout::set_keep_prob(const float keep_prob) { + CheckAndConvertUtils::CheckInRange(kKeepProb, keep_prob, kIncludeRight, {0.0, 1.0}, this->name()); + this->AddAttr(kKeepProb, MakeValue(keep_prob)); } + float Dropout::get_ratio() const { auto value_ptr = this->GetAttr(kRatio); return GetValue(value_ptr); } +float Dropout::get_keep_prob() const { + auto value_ptr = this->GetAttr(kKeepProb); + return GetValue(value_ptr); +} + AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); diff --git a/mindspore/core/ops/dropout.h b/mindspore/core/ops/dropout.h index f673e26287..3da66b5cf6 100644 --- a/mindspore/core/ops/dropout.h +++ b/mindspore/core/ops/dropout.h @@ -31,9 +31,11 @@ class Dropout : public PrimitiveC { Dropout() : PrimitiveC(kNameDropout) {} ~Dropout() = default; MS_DECLARE_PARENT(Dropout, PrimitiveC); - void Init(const float ratio = 0.5); + void Init(const float ratio = 0.5, const float keep_prob = 0.5); void set_ratio(const float ratio); + void set_keep_prob(const float keep_prob); float get_ratio() const; + float get_keep_prob() const; }; AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args);