Browse Source

add dropout attr keep_prop

tags/v1.1.1
jinyaohui 5 years ago
parent
commit
1305e08593
4 changed files with 47 additions and 3 deletions
  1. +1
    -1
      mindspore/core/ops/fusion/add_fusion.cc
  2. +14
    -1
      mindspore/core/ops/grad/dropout_grad.cc
  3. +3
    -1
      mindspore/core/ops/grad/dropout_grad.h
  4. +29
    -0
      tests/ut/cpp/ops/test_ops_add.cc

+ 1
- 1
mindspore/core/ops/fusion/add_fusion.cc View File

@@ -62,7 +62,7 @@ AbstractBasePtr AddFusionInfer(const abstract::AnalysisEnginePtr &, const Primit
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(AddFusion, prim::kPrimAdd, AddFusionInfer);
REGISTER_PRIMITIVE_EVAL_IMPL(AddFusion, prim::kPrimAddFusion, AddFusionInfer);
REGISTER_PRIMITIVE_C(kNameAddFusion, AddFusion);
} // namespace ops
} // namespace mindspore

+ 14
- 1
mindspore/core/ops/grad/dropout_grad.cc View File

@@ -19,18 +19,31 @@

namespace mindspore {
namespace ops {
void DropoutGrad::Init(const float ratio) { set_ratio(ratio); }
void DropoutGrad::Init(const float ratio, const float keep_prob) {
this->set_ratio(ratio);
this->set_keep_prob(keep_prob);
}

void DropoutGrad::set_ratio(const float ratio) {
CheckAndConvertUtils::CheckInRange<float>(kRatio, ratio, kIncludeRight, {0.0, 1.0}, this->name());
this->AddAttr(kRatio, MakeValue(ratio));
}

void DropoutGrad::set_keep_prob(const float keep_prob) {
CheckAndConvertUtils::CheckInRange<float>(kKeepProb, keep_prob, kIncludeRight, {0.0, 1.0}, this->name());
this->AddAttr(kKeepProb, MakeValue(keep_prob));
}

float DropoutGrad::get_ratio() const {
auto value_ptr = GetAttr(kRatio);
return GetValue<float>(value_ptr);
}

float DropoutGrad::get_keep_prob() const {
auto value_ptr = GetAttr(kKeepProb);
return GetValue<float>(value_ptr);
}

namespace {
abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {


+ 3
- 1
mindspore/core/ops/grad/dropout_grad.h View File

@@ -31,9 +31,11 @@ class DropoutGrad : public PrimitiveC {
DropoutGrad() : PrimitiveC(kNameDropoutGrad) {}
~DropoutGrad() = default;
MS_DECLARE_PARENT(DropoutGrad, PrimitiveC);
void Init(const float ratio = 0.5);
void Init(const float ratio = 0.5, const float keep_prob = 0.5);
void set_ratio(const float ratio);
void set_keep_prob(const float keep_prob);
float get_ratio() const;
float get_keep_prob() const;
};

AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 29
- 0
tests/ut/cpp/ops/test_ops_add.cc View File

@@ -28,5 +28,34 @@ class TestAdd : public UT::Common {
void SetUp() {}
void TearDown() {}
};

TEST_F(TestAdd, test_ops_add) {
auto add = std::make_shared<Add>();
add->Init();
auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{1, 3});
auto tensor_y = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{1, 3});
MS_EXCEPTION_IF_NULL(tensor_x);
MS_EXCEPTION_IF_NULL(tensor_y);
auto add_abstract = add->Infer({tensor_x->ToAbstract(), tensor_y->ToAbstract()});
MS_EXCEPTION_IF_NULL(add_abstract);
EXPECT_EQ(add_abstract->isa<abstract::AbstractTensor>(), true);
auto shape_ptr = add_abstract->BuildShape();
MS_EXCEPTION_IF_NULL(shape_ptr);
EXPECT_EQ(shape_ptr->isa<abstract::Shape>(), true);
auto add_shape = shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(add_shape);
auto shape_vec = add_shape->shape();
auto type = add_abstract->BuildType();
MS_EXCEPTION_IF_NULL(type);
EXPECT_EQ(type->isa<TensorType>(), true);
auto tensor_type = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto elem_type = tensor_type->element();
EXPECT_EQ(elem_type->type_id(), kNumberTypeFloat32);
EXPECT_EQ(shape_vec.size(), 2);
EXPECT_EQ(shape_vec[0], 1);
EXPECT_EQ(shape_vec[1], 3);
}

} // namespace ops
} // namespace mindspore

Loading…
Cancel
Save