diff --git a/mindspore/core/ops/grad/activation_grad.cc b/mindspore/core/ops/grad/activation_grad.cc
index 961394ec6f..09df3fd2b5 100644
--- a/mindspore/core/ops/grad/activation_grad.cc
+++ b/mindspore/core/ops/grad/activation_grad.cc
@@ -28,16 +28,16 @@
 namespace mindspore {
 namespace ops {
 void ActivationGrad::Init(const ActivationType &type, const float alpha) {
-  this->set_type(type);
+  this->set_activation_type(type);
   this->set_alpha(alpha);
 }
 
-void ActivationGrad::set_type(const ActivationType &type) {
+void ActivationGrad::set_activation_type(const ActivationType &type) {
   int64_t swi = type;
   this->AddAttr(kActivationType, MakeValue(swi));
 }
 
-ActivationType ActivationGrad::get_type() const {
+ActivationType ActivationGrad::get_activation_type() const {
   auto value_ptr = GetAttr(kActivationType);
   return ActivationType(GetValue<int64_t>(value_ptr));
 }
diff --git a/mindspore/core/ops/grad/activation_grad.h b/mindspore/core/ops/grad/activation_grad.h
index 2fb73eba3e..e8f3675211 100644
--- a/mindspore/core/ops/grad/activation_grad.h
+++ b/mindspore/core/ops/grad/activation_grad.h
@@ -33,10 +33,10 @@ class ActivationGrad : public PrimitiveC {
   ~ActivationGrad() = default;
   MS_DECLARE_PARENT(ActivationGrad, PrimitiveC);
   void Init(const ActivationType &type = NO_ACTIVATION, const float alpha = 0.2);
-  void set_type(const ActivationType &type);
+  void set_activation_type(const ActivationType &type);
   void set_alpha(const float alpha);
-  ActivationType get_type() const;
+  ActivationType get_activation_type() const;
   float get_alpha() const;
 };
 }  // namespace ops
diff --git a/mindspore/core/ops/grad/avg_pool_grad.cc b/mindspore/core/ops/grad/avg_pool_grad.cc
index 8b64867b44..5752bd5378 100644
--- a/mindspore/core/ops/grad/avg_pool_grad.cc
+++ b/mindspore/core/ops/grad/avg_pool_grad.cc
@@ -32,6 +32,7 @@ AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim
   return std::make_shared<abstract::AbstractTensor>(element, origin_input_shape);
 }
 
+REGISTER_PRIMITIVE_EVAL_IMPL(AvgPoolGrad, prim::kPrimAvgPoolGrad, AvgPoolGradInfer);
 REGISTER_PRIMITIVE_C(kNameAvgPoolGrad, AvgPoolGrad);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/grad/bias_grad.cc b/mindspore/core/ops/grad/bias_add_grad.cc
similarity index 67%
rename from mindspore/core/ops/grad/bias_grad.cc
rename to mindspore/core/ops/grad/bias_add_grad.cc
index dd2e4fd35e..8c9bcea071 100644
--- a/mindspore/core/ops/grad/bias_grad.cc
+++ b/mindspore/core/ops/grad/bias_add_grad.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
-#include "ops/grad/bias_grad.h"
+#include "ops/grad/bias_add_grad.h"
 #include
 #include
 #include
@@ -26,10 +26,22 @@
 namespace mindspore {
 namespace ops {
-AbstractBasePtr BiasGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                              const std::vector<AbstractBasePtr> &input_args) {
+void BiasAddGrad::Init(const Format format) { this->set_format(format); }
+
+void BiasAddGrad::set_format(const Format format) {
+  int64_t f = format;
+  AddAttr(kFormat, MakeValue(f));
+}
+
+Format BiasAddGrad::get_format() const {
+  auto value_ptr = GetAttr(kFormat);
+  return Format(GetValue<int64_t>(value_ptr));
+}
+
+AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto bias_prim = primitive->cast<PrimBiasGradPtr>();
+  auto bias_prim = primitive->cast<PrimBiasAddGradPtr>();
   MS_EXCEPTION_IF_NULL(bias_prim);
   auto prim_name = bias_prim->name();
   CheckAndConvertUtils::CheckInteger("bias_grad_infer", input_args.size(), kEqual, 1, prim_name);
@@ -46,6 +58,8 @@ AbstractBasePtr BiasGradInfer(const abstract::AnalysisEnginePtr &, const Primiti
   return std::make_shared<abstract::AbstractTensor>(intype, inshape);
 }
-REGISTER_PRIMITIVE_C(kNameBiasGrad, BiasGrad);
+
+REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer);
+REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/grad/bias_grad.h b/mindspore/core/ops/grad/bias_add_grad.h
similarity index 56%
rename from mindspore/core/ops/grad/bias_grad.h
rename to mindspore/core/ops/grad/bias_add_grad.h
index 6229e72bc3..7399a77cca 100644
--- a/mindspore/core/ops/grad/bias_grad.h
+++ b/mindspore/core/ops/grad/bias_add_grad.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CORE_OPS_BIAS_GRAD_H_
-#define MINDSPORE_CORE_OPS_BIAS_GRAD_H_
+#ifndef MINDSPORE_CORE_OPS_BIAS_ADD_GRAD_H_
+#define MINDSPORE_CORE_OPS_BIAS_ADD_GRAD_H_
 #include
 #include
 #include
@@ -26,18 +26,20 @@
 namespace mindspore {
 namespace ops {
-constexpr auto kNameBiasGrad = "BiasGrad";
-class BiasGrad : public PrimitiveC {
+constexpr auto kNameBiasAddGrad = "BiasAddGrad";
+class BiasAddGrad : public PrimitiveC {
  public:
-  BiasGrad() : PrimitiveC(kNameBiasGrad) {}
-  ~BiasGrad() = default;
-  MS_DECLARE_PARENT(BiasGrad, PrimitiveC);
-  void Init();
+  BiasAddGrad() : PrimitiveC(kNameBiasAddGrad) {}
+  ~BiasAddGrad() = default;
+  MS_DECLARE_PARENT(BiasAddGrad, PrimitiveC);
+  void Init(const Format format);
+  void set_format(const Format format);
+  Format get_format() const;
 };
-AbstractBasePtr BiasGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                              const std::vector<AbstractBasePtr> &input_args);
-using PrimBiasGradPtr = std::shared_ptr<BiasGrad>;
+AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const std::vector<AbstractBasePtr> &input_args);
+using PrimBiasAddGradPtr = std::shared_ptr<BiasAddGrad>;
 }  // namespace ops
 }  // namespace mindspore
-#endif  // MINDSPORE_CORE_OPS_BIAS_GRAD_H_
+#endif  // MINDSPORE_CORE_OPS_BIAS_ADD_GRAD_H_
diff --git a/mindspore/core/ops/grad/max_pool_grad.cc b/mindspore/core/ops/grad/max_pool_grad.cc
index 0c71271b2d..ae6c5a4a29 100644
--- a/mindspore/core/ops/grad/max_pool_grad.cc
+++ b/mindspore/core/ops/grad/max_pool_grad.cc
@@ -19,40 +19,11 @@
 namespace mindspore {
 namespace ops {
-void MaxPoolGrad::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &strides,
-                       const PadMode &pad_mode, const Format &data_format) {
-  this->set_data_format(data_format);
-  this->set_kernel_size(kernel_size);
-  this->set_strides(strides);
-  this->set_pad_mode(pad_mode);
-}
-
-void MaxPoolGrad::set_data_format(const Format &data_format) {
-  int64_t swi = data_format;
-  this->AddAttr(kFormat, MakeValue(swi));
-}
-
-Format MaxPoolGrad::get_data_format() const {
-  auto value_ptr = GetAttr(kFormat);
-  return Format(GetValue<int64_t>(value_ptr));
-}
-
-void MaxPoolGrad::set_kernel_size(const std::vector<int64_t> &kernel_size) {
-  std::vector<int64_t> k_size = _grad_check_vector(kSize, kernel_size, this->name());
-  k_size = this->get_data_format() == NCHW ? k_size : std::vector<int64_t>{k_size[0], k_size[2], k_size[3], k_size[1]};
-  this->AddAttr(kSize, MakeValue(k_size));
-}
-
-void MaxPoolGrad::set_strides(const std::vector<int64_t> &strides) {
-  std::vector<int64_t> stride_ = _grad_check_vector(kStrides, strides, this->name());
-  stride_ =
-    this->get_data_format() == NCHW ? stride_ : std::vector<int64_t>{stride_[0], stride_[2], stride_[3], stride_[1]};
-  this->AddAttr(kStrides, MakeValue(stride_));
-}
-
 AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
-  auto op_name = primitive->name();
+  auto MaxPoolGrad_prim = primitive->cast<PrimMaxPoolGradPtr>();
+  MS_EXCEPTION_IF_NULL(MaxPoolGrad_prim);
+  auto op_name = MaxPoolGrad_prim->name();
   MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
   auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name);
   auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
@@ -60,6 +31,8 @@ AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim
   auto element = tensor_type->element();
   return std::make_shared<abstract::AbstractTensor>(element, x1_shape);
 }
+
+REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolGrad, prim::kPrimMaxPoolGrad, MaxPoolGradInfer);
 REGISTER_PRIMITIVE_C(kNameMaxPoolGrad, MaxPoolGrad);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/grad/max_pool_grad.h b/mindspore/core/ops/grad/max_pool_grad.h
index 0d45beee22..18e0b497aa 100644
--- a/mindspore/core/ops/grad/max_pool_grad.h
+++ b/mindspore/core/ops/grad/max_pool_grad.h
@@ -33,12 +33,6 @@ class MaxPoolGrad : public PoolGrad {
   MaxPoolGrad() : PoolGrad(kNameMaxPoolGrad) { InitIOName({"x_origin", "out_origin", "grad"}, {"output"}); }
   ~MaxPoolGrad() = default;
   MS_DECLARE_PARENT(MaxPoolGrad, PoolGrad);
-  void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &strides = {1},
-            const PadMode &pad_mode = VALID, const Format &data_format = NCHW);
-  void set_kernel_size(const std::vector<int64_t> &kernel_size);
-  void set_strides(const std::vector<int64_t> &strides);
-  void set_data_format(const Format &data_format);
-  Format get_data_format() const;
 };
 
 AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/ops/grad/pool_grad.cc b/mindspore/core/ops/grad/pool_grad.cc
index 8bb965aa50..436bc8cb97 100644
--- a/mindspore/core/ops/grad/pool_grad.cc
+++ b/mindspore/core/ops/grad/pool_grad.cc
@@ -47,10 +47,11 @@ std::vector PoolGrad::_grad_check_vector(std::string arg_name, std::vec
 }
 
 void PoolGrad::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &strides,
-                    const PadMode &pad_mode) {
+                    const PadMode &pad_mode, const Format &format) {
   this->set_kernel_size(kernel_size);
   this->set_strides(strides);
   this->set_pad_mode(pad_mode);
+  this->set_format(format);
 }
 
 void PoolGrad::set_kernel_size(const std::vector<int64_t> &kernel_size) {
@@ -68,6 +69,11 @@ void PoolGrad::set_pad_mode(const PadMode &pad_mode) {
   this->AddAttr(kPadMode, MakeValue(swi));
 }
 
+void PoolGrad::set_format(const Format &format) {
+  int64_t swi = format;
+  this->AddAttr(kFormat, MakeValue(swi));
+}
+
 std::vector<int64_t> PoolGrad::get_kernel_size() const {
   auto value_ptr = GetAttr(kSize);
   return GetValue<std::vector<int64_t>>(value_ptr);
 }
@@ -82,6 +88,12 @@ PadMode PoolGrad::get_pad_mode() const {
   auto value_ptr = GetAttr(kPadMode);
   return PadMode(GetValue<int64_t>(value_ptr));
 }
+
+Format PoolGrad::get_format() const {
+  auto value_ptr = GetAttr(kFormat);
+  return Format(GetValue<int64_t>(value_ptr));
+}
+
 REGISTER_PRIMITIVE_C(kNamePoolGrad, PoolGrad);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/grad/pool_grad.h b/mindspore/core/ops/grad/pool_grad.h
index 2127d7069a..b6bd52ae29 100644
--- a/mindspore/core/ops/grad/pool_grad.h
+++ b/mindspore/core/ops/grad/pool_grad.h
@@ -36,14 +36,16 @@ class PoolGrad : public PrimitiveC {
   ~PoolGrad() = default;
   MS_DECLARE_PARENT(PoolGrad, PrimitiveC);
   virtual void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &strides = {1},
-                    const PadMode &pad_mode = VALID);
+                    const PadMode &pad_mode = VALID, const Format &format = NCHW);
   virtual void set_kernel_size(const std::vector<int64_t> &kernel_size);
   virtual void set_strides(const std::vector<int64_t> &strides);
   void set_pad_mode(const PadMode &pad_mode);
+  void set_format(const Format &format);
   std::vector<int64_t> get_kernel_size() const;
   std::vector<int64_t> get_strides() const;
   PadMode get_pad_mode() const;
+  Format get_format() const;
   std::vector<int64_t> _grad_check_vector(const std::string arg_name, const std::vector<int64_t> arg_val,
                                           const std::string op_name);
 };
diff --git a/mindspore/core/ops/lrn.cc b/mindspore/core/ops/lrn.cc
index f6897445ac..a86dd16c89 100644
--- a/mindspore/core/ops/lrn.cc
+++ b/mindspore/core/ops/lrn.cc
@@ -26,46 +26,46 @@
 namespace mindspore {
 namespace ops {
-void Lrn::set_depth_radius(const int64_t depth_radius) {
+void LRN::set_depth_radius(const int64_t depth_radius) {
   CheckAndConvertUtils::CheckInteger(kDepthRadius, depth_radius, kGreaterEqual, 0, this->name());
   this->AddAttr(kDepthRadius, MakeValue(depth_radius));
 }
 
-int64_t Lrn::get_depth_radius() const {
+int64_t LRN::get_depth_radius() const {
   auto value_ptr = GetAttr(kDepthRadius);
   return GetValue<int64_t>(value_ptr);
 }
 
-void Lrn::set_bias(const float bias) { this->AddAttr(kBias, MakeValue(bias)); }
+void LRN::set_bias(const float bias) { this->AddAttr(kBias, MakeValue(bias)); }
 
-float Lrn::get_bias() const {
+float LRN::get_bias() const {
   auto value_ptr = GetAttr(kBias);
   return GetValue<float>(value_ptr);
 }
 
-void Lrn::set_alpha(const float alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); }
+void LRN::set_alpha(const float alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); }
 
-float Lrn::get_alpha() const {
+float LRN::get_alpha() const {
   auto value_ptr = GetAttr(kAlpha);
   return GetValue<float>(value_ptr);
 }
 
-void Lrn::set_beta(const float beta) { this->AddAttr(kBeta, MakeValue(beta)); }
+void LRN::set_beta(const float beta) { this->AddAttr(kBeta, MakeValue(beta)); }
 
-float Lrn::get_beta() const {
+float LRN::get_beta() const {
   auto value_ptr = GetAttr(kBeta);
   return GetValue<float>(value_ptr);
 }
 
-void Lrn::set_norm_region(const std::string &norm_region) {
+void LRN::set_norm_region(const std::string &norm_region) {
   CheckAndConvertUtils::CheckString(kNormRegion, norm_region, {"ACROSS_CHANNELS"}, this->name());
   this->AddAttr(kNormRegion, MakeValue(norm_region));
 }
 
-std::string Lrn::get_norm_region() const {
+std::string LRN::get_norm_region() const {
   auto value_ptr = GetAttr(kNormRegion);
   return GetValue<std::string>(value_ptr);
 }
 
-void Lrn::Init(const int64_t depth_radius, const float bias, const float alpha, const float beta,
+void LRN::Init(const int64_t depth_radius, const float bias, const float alpha, const float beta,
                const std::string &norm_region) {
   this->set_depth_radius(depth_radius);
   this->set_bias(bias);
@@ -102,6 +102,7 @@ AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
   return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
                                                     InferShape(primitive, input_args)->shape());
 }
-REGISTER_PRIMITIVE_C(kNameLrn, Lrn);
+REGISTER_PRIMITIVE_EVAL_IMPL(LRN, prim::kPrimLrn, LrnInfer);
+REGISTER_PRIMITIVE_C(kNameLRN, LRN);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/lrn.h b/mindspore/core/ops/lrn.h
index ad9aa45252..12b02bff75 100644
--- a/mindspore/core/ops/lrn.h
+++ b/mindspore/core/ops/lrn.h
@@ -26,12 +26,12 @@
 namespace mindspore {
 namespace ops {
-constexpr auto kNameLrn = "Lrn";
-class Lrn : public PrimitiveC {
+constexpr auto kNameLRN = "Lrn";
+class LRN : public PrimitiveC {
  public:
-  Lrn() : PrimitiveC(kNameLrn) { InitIOName({"x"}, {"y"}); }
-  ~Lrn() = default;
-  MS_DECLARE_PARENT(Lrn, PrimitiveC);
+  LRN() : PrimitiveC(kNameLRN) { InitIOName({"x"}, {"y"}); }
+  ~LRN() = default;
+  MS_DECLARE_PARENT(LRN, PrimitiveC);
   void Init(const int64_t depth_radius = 5, const float bias = 1.0, const float alpha = 1.0, const float beta = 0.5,
             const std::string &norm_region = "ACROSS_CHANNELS");
   void set_depth_radius(const int64_t depth_radius);
@@ -47,7 +47,7 @@ class Lrn : public PrimitiveC {
 };
 AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                         const std::vector<AbstractBasePtr> &input_args);
-using PrimLrn = std::shared_ptr<Lrn>;
+using PrimLrn = std::shared_ptr<LRN>;
 }  // namespace ops
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_OPS_LRN_H_
diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h
index 47d7f7eac8..39a4076fb1 100644
--- a/mindspore/core/ops/op_utils.h
+++ b/mindspore/core/ops/op_utils.h
@@ -226,7 +226,6 @@
 constexpr auto kReduceToEnd = "reduce_to_end";
 constexpr auto kResetAfter = "reset_after";
 constexpr auto kCoeff = "coeff";
 constexpr auto kIsDepthWise = "is_depth_wise";
-constexpr auto kIsDepthWiseNative = "is_depth_wise_native";
 constexpr auto kZoneoutCell = "zoneout_cell";
 constexpr auto kZoneoutHidden = "zoneout_hidden";
diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc
index 92f1fcd9d6..a858b51b01 100644
--- a/mindspore/core/utils/check_convert_utils.cc
+++ b/mindspore/core/utils/check_convert_utils.cc
@@ -301,16 +301,16 @@ std::vector CheckAndConvertUtils::CheckPositiveVector(const std::string
       raise_message();
     }
   }
-  if (arg_value.size() == 1) {
-    return ret_four ? std::vector<int64_t>{1, 1, arg_value[0], arg_value[0]}
-                    : std::vector<int64_t>{arg_value[0], arg_value[0]};
-  }
-  if (arg_value.size() == 2) {
-    return ret_four ? std::vector<int64_t>{1, 1, arg_value[0], arg_value[1]} : arg_value;
-  } else if (arg_value.size() == 4 && allow_four) {
-    return ret_four ? arg_value : std::vector<int64_t>{arg_value[2], arg_value[3]};
-  }
-  raise_message();
+  // if (arg_value.size() == 1) {
+  //   return ret_four ? std::vector<int64_t>{1, 1, arg_value[0], arg_value[0]}
+  //                   : std::vector<int64_t>{arg_value[0], arg_value[0]};
+  // }
+  // if (arg_value.size() == 2) {
+  //   return ret_four ? std::vector<int64_t>{1, 1, arg_value[0], arg_value[1]} : arg_value;
+  // } else if (arg_value.size() == 4 && allow_four) {
+  //   return ret_four ? arg_value : std::vector<int64_t>{arg_value[2], arg_value[3]};
+  // }
+  // raise_message();
   return arg_value;
 }
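
Reviewer note, not part of the patch: every new format accessor above (BiasAddGrad::set_format/get_format, PoolGrad::set_format/get_format) follows the same convention of widening the Format enum to int64_t before storing it as a primitive attribute and narrowing it back to the enum on read. The following standalone C++ sketch illustrates that round trip; AttrHolder and the toy Format enum are hypothetical stand-ins, not the real MindSpore types.

// Standalone sketch of the widen-on-write / narrow-on-read attribute pattern.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

enum Format : std::int64_t { NCHW = 0, NHWC = 1 };

constexpr auto kFormat = "format";

class AttrHolder {
 public:
  // Mirrors AddAttr(kFormat, MakeValue(f)): the enum is stored as a plain int64_t.
  void set_format(Format format) { attrs_[kFormat] = static_cast<std::int64_t>(format); }
  // Mirrors Format(GetValue<int64_t>(value_ptr)): read the integer and cast back.
  Format get_format() const { return static_cast<Format>(attrs_.at(kFormat)); }

 private:
  std::map<std::string, std::int64_t> attrs_;
};

int main() {
  AttrHolder op;
  op.set_format(NHWC);
  std::cout << (op.get_format() == NHWC) << "\n";  // prints 1
  return 0;
}

The sketch keeps attributes in a plain std::map of int64_t values; MindSpore's real attribute store holds generic Value objects, but the widen-to-int64_t convention shown here is the same one the patch applies consistently across the renamed and newly added accessors.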