| @@ -62,7 +62,7 @@ AbstractBasePtr AddFusionInfer(const abstract::AnalysisEnginePtr &, const Primit | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | ||||
| InferShape(primitive, input_args)->shape()); | InferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(AddFusion, prim::kPrimAdd, AddFusionInfer); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(AddFusion, prim::kPrimAddFusion, AddFusionInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAddFusion, AddFusion); | REGISTER_PRIMITIVE_C(kNameAddFusion, AddFusion); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,18 +19,31 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| void DropoutGrad::Init(const float ratio) { set_ratio(ratio); } | |||||
| void DropoutGrad::Init(const float ratio, const float keep_prob) { | |||||
| this->set_ratio(ratio); | |||||
| this->set_keep_prob(keep_prob); | |||||
| } | |||||
| void DropoutGrad::set_ratio(const float ratio) { | void DropoutGrad::set_ratio(const float ratio) { | ||||
| CheckAndConvertUtils::CheckInRange<float>(kRatio, ratio, kIncludeRight, {0.0, 1.0}, this->name()); | CheckAndConvertUtils::CheckInRange<float>(kRatio, ratio, kIncludeRight, {0.0, 1.0}, this->name()); | ||||
| this->AddAttr(kRatio, MakeValue(ratio)); | this->AddAttr(kRatio, MakeValue(ratio)); | ||||
| } | } | ||||
| void DropoutGrad::set_keep_prob(const float keep_prob) { | |||||
| CheckAndConvertUtils::CheckInRange<float>(kKeepProb, keep_prob, kIncludeRight, {0.0, 1.0}, this->name()); | |||||
| this->AddAttr(kKeepProb, MakeValue(keep_prob)); | |||||
| } | |||||
| float DropoutGrad::get_ratio() const { | float DropoutGrad::get_ratio() const { | ||||
| auto value_ptr = GetAttr(kRatio); | auto value_ptr = GetAttr(kRatio); | ||||
| return GetValue<float>(value_ptr); | return GetValue<float>(value_ptr); | ||||
| } | } | ||||
| float DropoutGrad::get_keep_prob() const { | |||||
| auto value_ptr = GetAttr(kKeepProb); | |||||
| return GetValue<float>(value_ptr); | |||||
| } | |||||
| namespace { | namespace { | ||||
| abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, | abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| @@ -31,9 +31,11 @@ class DropoutGrad : public PrimitiveC { | |||||
| DropoutGrad() : PrimitiveC(kNameDropoutGrad) {} | DropoutGrad() : PrimitiveC(kNameDropoutGrad) {} | ||||
| ~DropoutGrad() = default; | ~DropoutGrad() = default; | ||||
| MS_DECLARE_PARENT(DropoutGrad, PrimitiveC); | MS_DECLARE_PARENT(DropoutGrad, PrimitiveC); | ||||
| void Init(const float ratio = 0.5); | |||||
| void Init(const float ratio = 0.5, const float keep_prob = 0.5); | |||||
| void set_ratio(const float ratio); | void set_ratio(const float ratio); | ||||
| void set_keep_prob(const float keep_prob); | |||||
| float get_ratio() const; | float get_ratio() const; | ||||
| float get_keep_prob() const; | |||||
| }; | }; | ||||
| AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -28,5 +28,34 @@ class TestAdd : public UT::Common { | |||||
| void SetUp() {} | void SetUp() {} | ||||
| void TearDown() {} | void TearDown() {} | ||||
| }; | }; | ||||
| TEST_F(TestAdd, test_ops_add) { | |||||
| auto add = std::make_shared<Add>(); | |||||
| add->Init(); | |||||
| auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{1, 3}); | |||||
| auto tensor_y = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{1, 3}); | |||||
| MS_EXCEPTION_IF_NULL(tensor_x); | |||||
| MS_EXCEPTION_IF_NULL(tensor_y); | |||||
| auto add_abstract = add->Infer({tensor_x->ToAbstract(), tensor_y->ToAbstract()}); | |||||
| MS_EXCEPTION_IF_NULL(add_abstract); | |||||
| EXPECT_EQ(add_abstract->isa<abstract::AbstractTensor>(), true); | |||||
| auto shape_ptr = add_abstract->BuildShape(); | |||||
| MS_EXCEPTION_IF_NULL(shape_ptr); | |||||
| EXPECT_EQ(shape_ptr->isa<abstract::Shape>(), true); | |||||
| auto add_shape = shape_ptr->cast<abstract::ShapePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(add_shape); | |||||
| auto shape_vec = add_shape->shape(); | |||||
| auto type = add_abstract->BuildType(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| EXPECT_EQ(type->isa<TensorType>(), true); | |||||
| auto tensor_type = type->cast<TensorTypePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor_type); | |||||
| auto elem_type = tensor_type->element(); | |||||
| EXPECT_EQ(elem_type->type_id(), kNumberTypeFloat32); | |||||
| EXPECT_EQ(shape_vec.size(), 2); | |||||
| EXPECT_EQ(shape_vec[0], 1); | |||||
| EXPECT_EQ(shape_vec[1], 3); | |||||
| } | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||