| @@ -28,16 +28,16 @@ namespace mindspore { | |||||
| namespace ops { | namespace ops { | ||||
| void ActivationGrad::Init(const ActivationType &type, const float alpha) { | void ActivationGrad::Init(const ActivationType &type, const float alpha) { | ||||
| this->set_type(type); | |||||
| this->set_activation_type(type); | |||||
| this->set_alpha(alpha); | this->set_alpha(alpha); | ||||
| } | } | ||||
| void ActivationGrad::set_type(const ActivationType &type) { | |||||
| void ActivationGrad::set_activation_type(const ActivationType &type) { | |||||
| int64_t swi = type; | int64_t swi = type; | ||||
| this->AddAttr(kActivationType, MakeValue(swi)); | this->AddAttr(kActivationType, MakeValue(swi)); | ||||
| } | } | ||||
| ActivationType ActivationGrad::get_type() const { | |||||
| ActivationType ActivationGrad::get_activation_type() const { | |||||
| auto value_ptr = GetAttr(kActivationType); | auto value_ptr = GetAttr(kActivationType); | ||||
| return ActivationType(GetValue<int64_t>(value_ptr)); | return ActivationType(GetValue<int64_t>(value_ptr)); | ||||
| } | } | ||||
| @@ -33,10 +33,10 @@ class ActivationGrad : public PrimitiveC { | |||||
| ~ActivationGrad() = default; | ~ActivationGrad() = default; | ||||
| MS_DECLARE_PARENT(ActivationGrad, PrimitiveC); | MS_DECLARE_PARENT(ActivationGrad, PrimitiveC); | ||||
| void Init(const ActivationType &type = NO_ACTIVATION, const float alpha = 0.2); | void Init(const ActivationType &type = NO_ACTIVATION, const float alpha = 0.2); | ||||
| void set_type(const ActivationType &type); | |||||
| void set_activation_type(const ActivationType &type); | |||||
| void set_alpha(const float alpha); | void set_alpha(const float alpha); | ||||
| ActivationType get_type() const; | |||||
| ActivationType get_activation_type() const; | |||||
| float get_alpha() const; | float get_alpha() const; | ||||
| }; | }; | ||||
| } // namespace ops | } // namespace ops | ||||
| @@ -32,6 +32,7 @@ AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim | |||||
| return std::make_shared<abstract::AbstractTensor>(element, origin_input_shape); | return std::make_shared<abstract::AbstractTensor>(element, origin_input_shape); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(AvgPoolGrad, prim::kPrimAvgPoolGrad, AvgPoolGradInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameAvgPoolGrad, AvgPoolGrad); | REGISTER_PRIMITIVE_C(kNameAvgPoolGrad, AvgPoolGrad); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ops/grad/bias_grad.h" | |||||
| #include "ops/grad/bias_add_grad.h" | |||||
| #include <string> | #include <string> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -26,10 +26,22 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| AbstractBasePtr BiasGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| void BiasAddGrad::Init(const Format format) { this->set_format(format); } | |||||
| void BiasAddGrad::set_format(const Format format) { | |||||
| int64_t f = format; | |||||
| AddAttr(kFormat, MakeValue(f)); | |||||
| } | |||||
| Format BiasAddGrad::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args) { | |||||
| MS_EXCEPTION_IF_NULL(primitive); | MS_EXCEPTION_IF_NULL(primitive); | ||||
| auto bias_prim = primitive->cast<PrimBiasGradPtr>(); | |||||
| auto bias_prim = primitive->cast<PrimBiasAddGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(bias_prim); | MS_EXCEPTION_IF_NULL(bias_prim); | ||||
| auto prim_name = bias_prim->name(); | auto prim_name = bias_prim->name(); | ||||
| CheckAndConvertUtils::CheckInteger("bias_grad_infer", input_args.size(), kEqual, 1, prim_name); | CheckAndConvertUtils::CheckInteger("bias_grad_infer", input_args.size(), kEqual, 1, prim_name); | ||||
| @@ -46,6 +58,8 @@ AbstractBasePtr BiasGradInfer(const abstract::AnalysisEnginePtr &, const Primiti | |||||
| return std::make_shared<abstract::AbstractTensor>(intype, inshape); | return std::make_shared<abstract::AbstractTensor>(intype, inshape); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameBiasGrad, BiasGrad); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad); | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CORE_OPS_BIAS_GRAD_H_ | |||||
| #define MINDSPORE_CORE_OPS_BIAS_GRAD_H_ | |||||
| #ifndef MINDSPORE_CORE_OPS_BIAS_ADD_GRAD_H_ | |||||
| #define MINDSPORE_CORE_OPS_BIAS_ADD_GRAD_H_ | |||||
| #include <map> | #include <map> | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| @@ -26,18 +26,20 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameBiasGrad = "BiasGrad"; | |||||
| class BiasGrad : public PrimitiveC { | |||||
| constexpr auto kNameBiasAddGrad = "BiasAddGrad"; | |||||
| class BiasAddGrad : public PrimitiveC { | |||||
| public: | public: | ||||
| BiasGrad() : PrimitiveC(kNameBiasGrad) {} | |||||
| ~BiasGrad() = default; | |||||
| MS_DECLARE_PARENT(BiasGrad, PrimitiveC); | |||||
| void Init(); | |||||
| BiasAddGrad() : PrimitiveC(kNameBiasAddGrad) {} | |||||
| ~BiasAddGrad() = default; | |||||
| MS_DECLARE_PARENT(BiasAddGrad, PrimitiveC); | |||||
| void Init(const Format format); | |||||
| void set_format(const Format format); | |||||
| Format get_format() const; | |||||
| }; | }; | ||||
| AbstractBasePtr BiasGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimBiasGradPtr = std::shared_ptr<BiasGrad>; | |||||
| AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const std::vector<AbstractBasePtr> &input_args); | |||||
| using PrimBiasAddGradPtr = std::shared_ptr<BiasAddGrad>; | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_OPS_BIAS_GRAD_H_ | |||||
| #endif // MINDSPORE_CORE_OPS_BIAS_ADD_GRAD_H_ | |||||
| @@ -19,40 +19,11 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| void MaxPoolGrad::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &strides, | |||||
| const PadMode &pad_mode, const Format &data_format) { | |||||
| this->set_data_format(data_format); | |||||
| this->set_kernel_size(kernel_size); | |||||
| this->set_strides(strides); | |||||
| this->set_pad_mode(pad_mode); | |||||
| } | |||||
| void MaxPoolGrad::set_data_format(const Format &data_format) { | |||||
| int64_t swi = data_format; | |||||
| this->AddAttr(kFormat, MakeValue(swi)); | |||||
| } | |||||
| Format MaxPoolGrad::get_data_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| void MaxPoolGrad::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||||
| std::vector<int64_t> k_size = _grad_check_vector(kSize, kernel_size, this->name()); | |||||
| k_size = this->get_data_format() == NCHW ? k_size : std::vector<int64_t>{k_size[0], k_size[2], k_size[3], k_size[1]}; | |||||
| this->AddAttr(kSize, MakeValue(k_size)); | |||||
| } | |||||
| void MaxPoolGrad::set_strides(const std::vector<int64_t> &strides) { | |||||
| std::vector<int64_t> stride_ = _grad_check_vector(kStrides, strides, this->name()); | |||||
| stride_ = | |||||
| this->get_data_format() == NCHW ? stride_ : std::vector<int64_t>{stride_[0], stride_[2], stride_[3], stride_[1]}; | |||||
| this->AddAttr(kStrides, MakeValue(stride_)); | |||||
| } | |||||
| AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args) { | const std::vector<AbstractBasePtr> &input_args) { | ||||
| auto op_name = primitive->name(); | |||||
| auto MaxPoolGrad_prim = primitive->cast<PrimMaxPoolGradPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(MaxPoolGrad_prim); | |||||
| auto op_name = MaxPoolGrad_prim->name(); | |||||
| MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); | ||||
| auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); | auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); | ||||
| auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); | ||||
| @@ -60,6 +31,8 @@ AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const Prim | |||||
| auto element = tensor_type->element(); | auto element = tensor_type->element(); | ||||
| return std::make_shared<abstract::AbstractTensor>(element, x1_shape); | return std::make_shared<abstract::AbstractTensor>(element, x1_shape); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_EVAL_IMPL(MaxPoolGrad, prim::kPrimMaxPoolGrad, MaxPoolGradInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameMaxPoolGrad, MaxPoolGrad); | REGISTER_PRIMITIVE_C(kNameMaxPoolGrad, MaxPoolGrad); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,12 +33,6 @@ class MaxPoolGrad : public PoolGrad { | |||||
| MaxPoolGrad() : PoolGrad(kNameMaxPoolGrad) { InitIOName({"x_origin", "out_origin", "grad"}, {"output"}); } | MaxPoolGrad() : PoolGrad(kNameMaxPoolGrad) { InitIOName({"x_origin", "out_origin", "grad"}, {"output"}); } | ||||
| ~MaxPoolGrad() = default; | ~MaxPoolGrad() = default; | ||||
| MS_DECLARE_PARENT(MaxPoolGrad, PoolGrad); | MS_DECLARE_PARENT(MaxPoolGrad, PoolGrad); | ||||
| void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &strides = {1}, | |||||
| const PadMode &pad_mode = VALID, const Format &data_format = NCHW); | |||||
| void set_kernel_size(const std::vector<int64_t> &kernel_size); | |||||
| void set_strides(const std::vector<int64_t> &strides); | |||||
| void set_data_format(const Format &data_format); | |||||
| Format get_data_format() const; | |||||
| }; | }; | ||||
| AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -47,10 +47,11 @@ std::vector<int64_t> PoolGrad::_grad_check_vector(std::string arg_name, std::vec | |||||
| } | } | ||||
| void PoolGrad::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &strides, | void PoolGrad::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &strides, | ||||
| const PadMode &pad_mode) { | |||||
| const PadMode &pad_mode, const Format &format) { | |||||
| this->set_kernel_size(kernel_size); | this->set_kernel_size(kernel_size); | ||||
| this->set_strides(strides); | this->set_strides(strides); | ||||
| this->set_pad_mode(pad_mode); | this->set_pad_mode(pad_mode); | ||||
| this->set_format(format); | |||||
| } | } | ||||
| void PoolGrad::set_kernel_size(const std::vector<int64_t> &kernel_size) { | void PoolGrad::set_kernel_size(const std::vector<int64_t> &kernel_size) { | ||||
| @@ -68,6 +69,11 @@ void PoolGrad::set_pad_mode(const PadMode &pad_mode) { | |||||
| this->AddAttr(kPadMode, MakeValue(swi)); | this->AddAttr(kPadMode, MakeValue(swi)); | ||||
| } | } | ||||
| void PoolGrad::set_format(const Format &format) { | |||||
| int64_t swi = format; | |||||
| this->AddAttr(kFormat, MakeValue(swi)); | |||||
| } | |||||
| std::vector<int64_t> PoolGrad::get_kernel_size() const { | std::vector<int64_t> PoolGrad::get_kernel_size() const { | ||||
| auto value_ptr = GetAttr(kSize); | auto value_ptr = GetAttr(kSize); | ||||
| return GetValue<std::vector<int64_t>>(value_ptr); | return GetValue<std::vector<int64_t>>(value_ptr); | ||||
| @@ -82,6 +88,12 @@ PadMode PoolGrad::get_pad_mode() const { | |||||
| auto value_ptr = GetAttr(kPadMode); | auto value_ptr = GetAttr(kPadMode); | ||||
| return PadMode(GetValue<int64_t>(value_ptr)); | return PadMode(GetValue<int64_t>(value_ptr)); | ||||
| } | } | ||||
| Format PoolGrad::get_format() const { | |||||
| auto value_ptr = GetAttr(kFormat); | |||||
| return Format(GetValue<int64_t>(value_ptr)); | |||||
| } | |||||
| REGISTER_PRIMITIVE_C(kNamePoolGrad, PoolGrad); | REGISTER_PRIMITIVE_C(kNamePoolGrad, PoolGrad); | ||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,14 +36,16 @@ class PoolGrad : public PrimitiveC { | |||||
| ~PoolGrad() = default; | ~PoolGrad() = default; | ||||
| MS_DECLARE_PARENT(PoolGrad, PrimitiveC); | MS_DECLARE_PARENT(PoolGrad, PrimitiveC); | ||||
| virtual void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &strides = {1}, | virtual void Init(const std::vector<int64_t> &kernel_size = {1}, const std::vector<int64_t> &strides = {1}, | ||||
| const PadMode &pad_mode = VALID); | |||||
| const PadMode &pad_mode = VALID, const Format &format = NCHW); | |||||
| virtual void set_kernel_size(const std::vector<int64_t> &kernel_size); | virtual void set_kernel_size(const std::vector<int64_t> &kernel_size); | ||||
| virtual void set_strides(const std::vector<int64_t> &strides); | virtual void set_strides(const std::vector<int64_t> &strides); | ||||
| void set_pad_mode(const PadMode &pad_mode); | void set_pad_mode(const PadMode &pad_mode); | ||||
| void set_format(const Format &format); | |||||
| std::vector<int64_t> get_kernel_size() const; | std::vector<int64_t> get_kernel_size() const; | ||||
| std::vector<int64_t> get_strides() const; | std::vector<int64_t> get_strides() const; | ||||
| PadMode get_pad_mode() const; | PadMode get_pad_mode() const; | ||||
| Format get_format() const; | |||||
| std::vector<int64_t> _grad_check_vector(const std::string arg_name, const std::vector<int64_t> arg_val, | std::vector<int64_t> _grad_check_vector(const std::string arg_name, const std::vector<int64_t> arg_val, | ||||
| const std::string op_name); | const std::string op_name); | ||||
| }; | }; | ||||
| @@ -26,46 +26,46 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| void Lrn::set_depth_radius(const int64_t depth_radius) { | |||||
| void LRN::set_depth_radius(const int64_t depth_radius) { | |||||
| CheckAndConvertUtils::CheckInteger(kDepthRadius, depth_radius, kGreaterEqual, 0, this->name()); | CheckAndConvertUtils::CheckInteger(kDepthRadius, depth_radius, kGreaterEqual, 0, this->name()); | ||||
| this->AddAttr(kDepthRadius, MakeValue(depth_radius)); | this->AddAttr(kDepthRadius, MakeValue(depth_radius)); | ||||
| } | } | ||||
| int64_t Lrn::get_depth_radius() const { | |||||
| int64_t LRN::get_depth_radius() const { | |||||
| auto value_ptr = GetAttr(kDepthRadius); | auto value_ptr = GetAttr(kDepthRadius); | ||||
| return GetValue<int64_t>(value_ptr); | return GetValue<int64_t>(value_ptr); | ||||
| } | } | ||||
| void Lrn::set_bias(const float bias) { this->AddAttr(kBias, MakeValue(bias)); } | |||||
| void LRN::set_bias(const float bias) { this->AddAttr(kBias, MakeValue(bias)); } | |||||
| float Lrn::get_bias() const { | |||||
| float LRN::get_bias() const { | |||||
| auto value_ptr = GetAttr(kBias); | auto value_ptr = GetAttr(kBias); | ||||
| return GetValue<float>(value_ptr); | return GetValue<float>(value_ptr); | ||||
| } | } | ||||
| void Lrn::set_alpha(const float alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); } | |||||
| void LRN::set_alpha(const float alpha) { this->AddAttr(kAlpha, MakeValue(alpha)); } | |||||
| float Lrn::get_alpha() const { | |||||
| float LRN::get_alpha() const { | |||||
| auto value_ptr = GetAttr(kAlpha); | auto value_ptr = GetAttr(kAlpha); | ||||
| return GetValue<float>(value_ptr); | return GetValue<float>(value_ptr); | ||||
| } | } | ||||
| void Lrn::set_beta(const float beta) { this->AddAttr(kBeta, MakeValue(beta)); } | |||||
| void LRN::set_beta(const float beta) { this->AddAttr(kBeta, MakeValue(beta)); } | |||||
| float Lrn::get_beta() const { | |||||
| float LRN::get_beta() const { | |||||
| auto value_ptr = GetAttr(kBeta); | auto value_ptr = GetAttr(kBeta); | ||||
| return GetValue<float>(value_ptr); | return GetValue<float>(value_ptr); | ||||
| } | } | ||||
| void Lrn::set_norm_region(const std::string &norm_region) { | |||||
| void LRN::set_norm_region(const std::string &norm_region) { | |||||
| CheckAndConvertUtils::CheckString(kNormRegion, norm_region, {"ACROSS_CHANNELS"}, this->name()); | CheckAndConvertUtils::CheckString(kNormRegion, norm_region, {"ACROSS_CHANNELS"}, this->name()); | ||||
| this->AddAttr(kNormRegion, MakeValue(norm_region)); | this->AddAttr(kNormRegion, MakeValue(norm_region)); | ||||
| } | } | ||||
| std::string Lrn::get_norm_region() const { | |||||
| std::string LRN::get_norm_region() const { | |||||
| auto value_ptr = GetAttr(kNormRegion); | auto value_ptr = GetAttr(kNormRegion); | ||||
| return GetValue<std::string>(value_ptr); | return GetValue<std::string>(value_ptr); | ||||
| } | } | ||||
| void Lrn::Init(const int64_t depth_radius, const float bias, const float alpha, const float beta, | |||||
| void LRN::Init(const int64_t depth_radius, const float bias, const float alpha, const float beta, | |||||
| const std::string &norm_region) { | const std::string &norm_region) { | ||||
| this->set_depth_radius(depth_radius); | this->set_depth_radius(depth_radius); | ||||
| this->set_bias(bias); | this->set_bias(bias); | ||||
| @@ -102,6 +102,7 @@ AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr | |||||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | ||||
| InferShape(primitive, input_args)->shape()); | InferShape(primitive, input_args)->shape()); | ||||
| } | } | ||||
| REGISTER_PRIMITIVE_C(kNameLrn, Lrn); | |||||
| REGISTER_PRIMITIVE_EVAL_IMPL(LRN, prim::kPrimLrn, LrnInfer); | |||||
| REGISTER_PRIMITIVE_C(kNameLRN, LRN); | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,12 +26,12 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ops { | namespace ops { | ||||
| constexpr auto kNameLrn = "Lrn"; | |||||
| class Lrn : public PrimitiveC { | |||||
| constexpr auto kNameLRN = "Lrn"; | |||||
| class LRN : public PrimitiveC { | |||||
| public: | public: | ||||
| Lrn() : PrimitiveC(kNameLrn) { InitIOName({"x"}, {"y"}); } | |||||
| ~Lrn() = default; | |||||
| MS_DECLARE_PARENT(Lrn, PrimitiveC); | |||||
| LRN() : PrimitiveC(kNameLRN) { InitIOName({"x"}, {"y"}); } | |||||
| ~LRN() = default; | |||||
| MS_DECLARE_PARENT(LRN, PrimitiveC); | |||||
| void Init(const int64_t depth_radius = 5, const float bias = 1.0, const float alpha = 1.0, const float beta = 0.5, | void Init(const int64_t depth_radius = 5, const float bias = 1.0, const float alpha = 1.0, const float beta = 0.5, | ||||
| const std::string &norm_region = "ACROSS_CHANNELS"); | const std::string &norm_region = "ACROSS_CHANNELS"); | ||||
| void set_depth_radius(const int64_t depth_radius); | void set_depth_radius(const int64_t depth_radius); | ||||
| @@ -47,7 +47,7 @@ class Lrn : public PrimitiveC { | |||||
| }; | }; | ||||
| AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr LrnInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const std::vector<AbstractBasePtr> &input_args); | const std::vector<AbstractBasePtr> &input_args); | ||||
| using PrimLrn = std::shared_ptr<Lrn>; | |||||
| using PrimLrn = std::shared_ptr<LRN>; | |||||
| } // namespace ops | } // namespace ops | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_OPS_LRN_H_ | #endif // MINDSPORE_CORE_OPS_LRN_H_ | ||||
| @@ -226,7 +226,6 @@ constexpr auto kReduceToEnd = "reduce_to_end"; | |||||
| constexpr auto kResetAfter = "reset_after"; | constexpr auto kResetAfter = "reset_after"; | ||||
| constexpr auto kCoeff = "coeff"; | constexpr auto kCoeff = "coeff"; | ||||
| constexpr auto kIsDepthWise = "is_depth_wise"; | constexpr auto kIsDepthWise = "is_depth_wise"; | ||||
| constexpr auto kIsDepthWiseNative = "is_depth_wise_native"; | |||||
| constexpr auto kZoneoutCell = "zoneout_cell"; | constexpr auto kZoneoutCell = "zoneout_cell"; | ||||
| constexpr auto kZoneoutHidden = "zoneout_hidden"; | constexpr auto kZoneoutHidden = "zoneout_hidden"; | ||||
| @@ -301,16 +301,16 @@ std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string | |||||
| raise_message(); | raise_message(); | ||||
| } | } | ||||
| } | } | ||||
| if (arg_value.size() == 1) { | |||||
| return ret_four ? std::vector<int64_t>{1, 1, arg_value[0], arg_value[0]} | |||||
| : std::vector<int64_t>{arg_value[0], arg_value[0]}; | |||||
| } | |||||
| if (arg_value.size() == 2) { | |||||
| return ret_four ? std::vector<int64_t>{1, 1, arg_value[0], arg_value[1]} : arg_value; | |||||
| } else if (arg_value.size() == 4 && allow_four) { | |||||
| return ret_four ? arg_value : std::vector<int64_t>{arg_value[2], arg_value[3]}; | |||||
| } | |||||
| raise_message(); | |||||
| // if (arg_value.size() == 1) { | |||||
| // return ret_four ? std::vector<int64_t>{1, 1, arg_value[0], arg_value[0]} | |||||
| // : std::vector<int64_t>{arg_value[0], arg_value[0]}; | |||||
| // } | |||||
| // if (arg_value.size() == 2) { | |||||
| // return ret_four ? std::vector<int64_t>{1, 1, arg_value[0], arg_value[1]} : arg_value; | |||||
| // } else if (arg_value.size() == 4 && allow_four) { | |||||
| // return ret_four ? arg_value : std::vector<int64_t>{arg_value[2], arg_value[3]}; | |||||
| // } | |||||
| // raise_message(); | |||||
| return arg_value; | return arg_value; | ||||
| } | } | ||||