|
|
|
@@ -24,16 +24,31 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace ops { |
|
|
|
void Dropout::Init(const float ratio) { this->set_ratio(ratio); } |
|
|
|
void Dropout::Init(const float ratio, const float keep_prob) { |
|
|
|
this->set_ratio(ratio); |
|
|
|
this->set_keep_prob(keep_prob); |
|
|
|
} |
|
|
|
|
|
|
|
void Dropout::set_ratio(const float ratio) { |
|
|
|
CheckAndConvertUtils::CheckInRange<float>(kRatio, ratio, kIncludeRight, {0.0, 1.0}, this->name()); |
|
|
|
this->AddAttr(kKeepProb, MakeValue(ratio)); |
|
|
|
this->AddAttr(kRatio, MakeValue(ratio)); |
|
|
|
} |
|
|
|
|
|
|
|
void Dropout::set_keep_prob(const float keep_prob) { |
|
|
|
CheckAndConvertUtils::CheckInRange<float>(kKeepProb, keep_prob, kIncludeRight, {0.0, 1.0}, this->name()); |
|
|
|
this->AddAttr(kKeepProb, MakeValue(keep_prob)); |
|
|
|
} |
|
|
|
|
|
|
|
float Dropout::get_ratio() const { |
|
|
|
auto value_ptr = this->GetAttr(kRatio); |
|
|
|
return GetValue<float>(value_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
float Dropout::get_keep_prob() const { |
|
|
|
auto value_ptr = this->GetAttr(kKeepProb); |
|
|
|
return GetValue<float>(value_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
|