diff --git a/mindspore/core/ops/fusion/add_fusion.cc b/mindspore/core/ops/fusion/add_fusion.cc index c32ff9e1ac..cdd05e84be 100644 --- a/mindspore/core/ops/fusion/add_fusion.cc +++ b/mindspore/core/ops/fusion/add_fusion.cc @@ -62,7 +62,7 @@ AbstractBasePtr AddFusionInfer(const abstract::AnalysisEnginePtr &, const Primit return std::make_shared(InferType(primitive, input_args), InferShape(primitive, input_args)->shape()); } -REGISTER_PRIMITIVE_EVAL_IMPL(AddFusion, prim::kPrimAdd, AddFusionInfer); +REGISTER_PRIMITIVE_EVAL_IMPL(AddFusion, prim::kPrimAddFusion, AddFusionInfer); REGISTER_PRIMITIVE_C(kNameAddFusion, AddFusion); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/grad/dropout_grad.cc b/mindspore/core/ops/grad/dropout_grad.cc index 464f1d98d3..3c9e6af94f 100644 --- a/mindspore/core/ops/grad/dropout_grad.cc +++ b/mindspore/core/ops/grad/dropout_grad.cc @@ -19,18 +19,31 @@ namespace mindspore { namespace ops { -void DropoutGrad::Init(const float ratio) { set_ratio(ratio); } +void DropoutGrad::Init(const float ratio, const float keep_prob) { + this->set_ratio(ratio); + this->set_keep_prob(keep_prob); +} void DropoutGrad::set_ratio(const float ratio) { CheckAndConvertUtils::CheckInRange(kRatio, ratio, kIncludeRight, {0.0, 1.0}, this->name()); this->AddAttr(kRatio, MakeValue(ratio)); } +void DropoutGrad::set_keep_prob(const float keep_prob) { + CheckAndConvertUtils::CheckInRange(kKeepProb, keep_prob, kIncludeRight, {0.0, 1.0}, this->name()); + this->AddAttr(kKeepProb, MakeValue(keep_prob)); +} + float DropoutGrad::get_ratio() const { auto value_ptr = GetAttr(kRatio); return GetValue(value_ptr); } +float DropoutGrad::get_keep_prob() const { + auto value_ptr = GetAttr(kKeepProb); + return GetValue(value_ptr); +} + namespace { abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { diff --git a/mindspore/core/ops/grad/dropout_grad.h b/mindspore/core/ops/grad/dropout_grad.h index bc92d5a517..9881dd36a9 100644 --- a/mindspore/core/ops/grad/dropout_grad.h +++ b/mindspore/core/ops/grad/dropout_grad.h @@ -31,9 +31,11 @@ class DropoutGrad : public PrimitiveC { DropoutGrad() : PrimitiveC(kNameDropoutGrad) {} ~DropoutGrad() = default; MS_DECLARE_PARENT(DropoutGrad, PrimitiveC); - void Init(const float ratio = 0.5); + void Init(const float ratio = 0.5, const float keep_prob = 0.5); void set_ratio(const float ratio); + void set_keep_prob(const float keep_prob); float get_ratio() const; + float get_keep_prob() const; }; AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/tests/ut/cpp/ops/test_ops_add.cc b/tests/ut/cpp/ops/test_ops_add.cc index 41a5592c6b..ce22771179 100644 --- a/tests/ut/cpp/ops/test_ops_add.cc +++ b/tests/ut/cpp/ops/test_ops_add.cc @@ -28,5 +28,34 @@ class TestAdd : public UT::Common { void SetUp() {} void TearDown() {} }; + +TEST_F(TestAdd, test_ops_add) { + auto add = std::make_shared(); + add->Init(); + auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector{1, 3}); + auto tensor_y = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector{1, 3}); + MS_EXCEPTION_IF_NULL(tensor_x); + MS_EXCEPTION_IF_NULL(tensor_y); + auto add_abstract = add->Infer({tensor_x->ToAbstract(), tensor_y->ToAbstract()}); + MS_EXCEPTION_IF_NULL(add_abstract); + EXPECT_EQ(add_abstract->isa(), true); + auto shape_ptr = add_abstract->BuildShape(); + MS_EXCEPTION_IF_NULL(shape_ptr); + EXPECT_EQ(shape_ptr->isa(), true); + auto add_shape = shape_ptr->cast(); + MS_EXCEPTION_IF_NULL(add_shape); + auto shape_vec = add_shape->shape(); + auto type = add_abstract->BuildType(); + MS_EXCEPTION_IF_NULL(type); + EXPECT_EQ(type->isa(), true); + auto tensor_type = type->cast(); + MS_EXCEPTION_IF_NULL(tensor_type); + auto elem_type = tensor_type->element(); + EXPECT_EQ(elem_type->type_id(), kNumberTypeFloat32); + EXPECT_EQ(shape_vec.size(), 2); + EXPECT_EQ(shape_vec[0], 1); + EXPECT_EQ(shape_vec[1], 3); +} + } // namespace ops } // namespace mindspore