Browse Source

!11642 add dropout attr keep_prop

From: @jinyaohui
Reviewed-by: @kingxian,@hangangqiang
Signed-off-by: @kingxian
tags/v1.1.1
mindspore-ci-bot Gitee 5 years ago
parent
commit
8a4e0cc6f6
2 changed files with 20 additions and 3 deletions
  1. +17
    -2
      mindspore/core/ops/dropout.cc
  2. +3
    -1
      mindspore/core/ops/dropout.h

+ 17
- 2
mindspore/core/ops/dropout.cc View File

@@ -24,16 +24,31 @@

namespace mindspore {
namespace ops {
void Dropout::Init(const float ratio) { this->set_ratio(ratio); }
void Dropout::Init(const float ratio, const float keep_prob) {
this->set_ratio(ratio);
this->set_keep_prob(keep_prob);
}

void Dropout::set_ratio(const float ratio) {
CheckAndConvertUtils::CheckInRange<float>(kRatio, ratio, kIncludeRight, {0.0, 1.0}, this->name());
this->AddAttr(kKeepProb, MakeValue(ratio));
this->AddAttr(kRatio, MakeValue(ratio));
}

void Dropout::set_keep_prob(const float keep_prob) {
CheckAndConvertUtils::CheckInRange<float>(kKeepProb, keep_prob, kIncludeRight, {0.0, 1.0}, this->name());
this->AddAttr(kKeepProb, MakeValue(keep_prob));
}

float Dropout::get_ratio() const {
auto value_ptr = this->GetAttr(kRatio);
return GetValue<float>(value_ptr);
}

float Dropout::get_keep_prob() const {
auto value_ptr = this->GetAttr(kKeepProb);
return GetValue<float>(value_ptr);
}

AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);


+ 3
- 1
mindspore/core/ops/dropout.h View File

@@ -31,9 +31,11 @@ class Dropout : public PrimitiveC {
Dropout() : PrimitiveC(kNameDropout) {}
~Dropout() = default;
MS_DECLARE_PARENT(Dropout, PrimitiveC);
void Init(const float ratio = 0.5);
void Init(const float ratio = 0.5, const float keep_prob = 0.5);
void set_ratio(const float ratio);
void set_keep_prob(const float keep_prob);
float get_ratio() const;
float get_keep_prob() const;
};
AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);


Loading…
Cancel
Save