From a0d25d8b57d2f7ad8c4e68548670155070c2ab32 Mon Sep 17 00:00:00 2001
From: LianLiguang
Date: Fri, 16 Apr 2021 09:54:28 +0800
Subject: [PATCH] avoid cast when doing infer

---
 mindspore/core/abstract/infer_functions.h | 20 ---
 mindspore/core/abstract/prim_nn.cc | 17 --
 mindspore/core/ops/abs.cc | 4 +-
 mindspore/core/ops/adam.cc | 4 +-
 mindspore/core/ops/add.cc | 4 +-
 mindspore/core/ops/arg_max.cc | 4 +-
 mindspore/core/ops/arg_min.cc | 11 +-
 mindspore/core/ops/asin.cc | 4 +-
 mindspore/core/ops/assert.cc | 4 +-
 mindspore/core/ops/assign_add.cc | 4 +-
 mindspore/core/ops/atan.cc | 4 +-
 mindspore/core/ops/audio_spectrogram.cc | 24 +--
 mindspore/core/ops/audio_spectrogram.h | 4 +-
 mindspore/core/ops/avg_pool.cc | 36 ++--
 mindspore/core/ops/batch_norm.cc | 14 +-
 mindspore/core/ops/batch_norm_fold.cc | 4 +-
 mindspore/core/ops/batch_to_space.cc | 8 +-
 mindspore/core/ops/batch_to_space_nd.cc | 8 +-
 mindspore/core/ops/binary_cross_entropy.cc | 8 +-
 mindspore/core/ops/broadcast.cc | 4 +-
 mindspore/core/ops/concat.cc | 6 +-
 mindspore/core/ops/constant_of_shape.cc | 4 +-
 mindspore/core/ops/conv2d_transpose.cc | 4 +-
 mindspore/core/ops/cos.cc | 4 +-
 mindspore/core/ops/crop.cc | 4 +-
 mindspore/core/ops/custom_extract_features.cc | 4 +-
 mindspore/core/ops/custom_normalize.cc | 5 -
 mindspore/core/ops/custom_predict.cc | 4 +-
 mindspore/core/ops/depth_to_space.cc | 21 +--
 mindspore/core/ops/depthwise_conv2d.cc | 67 +++-----
 mindspore/core/ops/detection_post_process.cc | 33 ++--
 mindspore/core/ops/div.cc | 4 +-
 mindspore/core/ops/dropout.cc | 4 +-
 mindspore/core/ops/elu.cc | 4 +-
 mindspore/core/ops/equal.cc | 4 +-
 mindspore/core/ops/expand_dims.cc | 4 +-
 .../core/ops/fake_quant_with_min_max_vars.cc | 4 +-
 ...ake_quant_with_min_max_vars_per_channel.cc | 4 +-
 mindspore/core/ops/fft_imag.cc | 4 +-
 mindspore/core/ops/flatten.cc | 4 +-
 mindspore/core/ops/floor.cc | 4 +-
 mindspore/core/ops/fusion/add_fusion.cc | 4 +-
 mindspore/core/ops/fusion/avg_pool_fusion.cc | 15 +-
 mindspore/core/ops/fusion/full_connection.cc | 33 ++--
 mindspore/core/ops/fusion/max_pool_fusion.cc | 15 +-
 mindspore/core/ops/fusion/pow_fusion.cc | 4 +-
 mindspore/core/ops/fusion/slice_fusion.cc | 4 +-
 mindspore/core/ops/gather_nd.cc | 4 +-
 mindspore/core/ops/gelu.cc | 4 +-
 mindspore/core/ops/grad/avg_pool_grad.cc | 2 -
 mindspore/core/ops/grad/batch_norm_grad.cc | 4 +-
 mindspore/core/ops/grad/bias_add_grad.cc | 4 +-
 .../ops/grad/binary_cross_entropy_grad.cc | 4 +-
 mindspore/core/ops/grad/dropout_grad.cc | 8 +-
 .../core/ops/grad/group_conv2d_grad_input.cc | 9 +-
 mindspore/core/ops/grad/max_pool_grad.cc | 5 +-
 .../sigmoid_cross_entropy_with_logits_grad.cc | 4 +-
 .../core/ops/grad/smooth_l1_loss_grad.cc | 4 +-
 mindspore/core/ops/hashtable_lookup.cc | 4 +-
 mindspore/core/ops/l2_normalize.cc | 11 +-
 mindspore/core/ops/less.cc | 4 +-
 mindspore/core/ops/less_equal.cc | 4 +-
 mindspore/core/ops/logical_and.cc | 4 +-
 mindspore/core/ops/logical_not.cc | 8 +-
 mindspore/core/ops/logical_or.cc | 4 +-
 mindspore/core/ops/lrn.cc | 4 +-
 mindspore/core/ops/lsh_projection.cc | 14 +-
 mindspore/core/ops/lstm.cc | 56 +++---
 mindspore/core/ops/matrix_diag.cc | 4 +-
 mindspore/core/ops/max_pool.cc | 20 +--
 mindspore/core/ops/maximum.cc | 4 +-
 mindspore/core/ops/merge.cc | 4 +-
 mindspore/core/ops/mfcc.cc | 12 +-
 mindspore/core/ops/non_max_suppression.cc | 3 -
 mindspore/core/ops/one_hot.cc | 15 +-
 mindspore/core/ops/ones_like.cc | 4 +-
 mindspore/core/ops/pack.cc | 11 +-
 mindspore/core/ops/pad.cc | 9 +-
 mindspore/core/ops/pow.cc | 4 +-
 mindspore/core/ops/prior_box.cc | 27 +--
 mindspore/core/ops/quant_dtype_cast.cc | 15 +-
 mindspore/core/ops/range.cc | 16 +-
 mindspore/core/ops/rank.cc | 4 +-
 mindspore/core/ops/reciprocal.cc | 4 +-
 mindspore/core/ops/reduce.cc | 11 +-
 mindspore/core/ops/resize_bilinear.cc | 11 +-
 mindspore/core/ops/reverse_sequence.cc | 13 +-
 mindspore/core/ops/reverse_v2.cc | 4 +-
 mindspore/core/ops/rfft.cc | 11 +-
 mindspore/core/ops/roi_pooling.cc | 13 +-
 mindspore/core/ops/rsqrt.cc | 4 +-
 .../ops/sigmoid_cross_entropy_with_logits.cc | 4 +-
 mindspore/core/ops/skip_gram.cc | 4 +-
 mindspore/core/ops/smooth_l1_loss.cc | 4 +-
 .../ops/softmax_cross_entropy_with_logits.cc | 4 +-
 mindspore/core/ops/space_to_batch.cc | 11 +-
 mindspore/core/ops/space_to_batch_nd.cc | 11 +-
 ...parse_softmax_cross_entropy_with_logits.cc | 11 +-
 mindspore/core/ops/sparse_to_dense.cc | 4 +-
 mindspore/core/ops/squared_difference.cc | 4 +-
 mindspore/core/ops/squeeze.cc | 11 +-
 mindspore/core/ops/stack.cc | 11 +-
 mindspore/core/ops/strided_slice.cc | 162 +++++++++---------
 mindspore/core/ops/sub.cc | 4 +-
 mindspore/core/ops/tan.cc | 4 +-
 mindspore/core/ops/tensor_list_from_tensor.cc | 4 +-
 mindspore/core/ops/tensor_list_stack.cc | 6 +-
 mindspore/core/ops/tile.cc | 4 +-
 mindspore/core/ops/topk.cc | 4 +-
 mindspore/core/ops/unpack.cc | 11 +-
 mindspore/core/ops/unsorted_segment_sum.cc | 4 +-
 mindspore/core/ops/unsqueeze.cc | 11 +-
 mindspore/core/ops/unstack.cc | 11 +-
 mindspore/core/ops/where.cc | 4 +-
 mindspore/core/ops/zeros_like.cc | 4 +-
 115 files changed, 375 insertions(+), 786 deletions(-)

diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index bb2f8e1ca5..4a645007b2 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -53,22 +53,10 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr
                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                             const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                              const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                  const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplFastGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                  const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplFastGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -153,10 +141,6 @@ AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &
                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                       const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                         const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
 
@@ -174,8 +158,6 @@ AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -302,8 +284,6 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri
                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                                             const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc
index 2916f498e2..a3c589866c 100644
--- a/mindspore/core/abstract/prim_nn.cc
+++ b/mindspore/core/abstract/prim_nn.cc
@@ -140,23 +140,6 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
   }
 }
 
-AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                                             const AbstractBasePtrList &args_spec_list) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto is_grad = GetValue<bool>(primitive->GetAttr("is_grad"));
-  CheckArgsSize(primitive->name(), args_spec_list, 2);
-  std::shared_ptr<BaseShape> shape = std::make_shared<Shape>(std::vector<int64_t>{});
-  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
-  if (is_grad) {
-    shape = args_spec_list[0]->BuildShape();
-  }
-  auto type = args_spec_list[0]->BuildType();
-  MS_EXCEPTION_IF_NULL(type);
-  auto type_tensor = type->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(type_tensor);
-  return std::make_shared<AbstractTensor>(type_tensor->element(), shape);
-}
-
 AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
   // Inputs: five tensors(x, gamma, beta, mean, variance).
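Every hunk from this point on applies the same mechanical change: instead of down-casting the generic PrimitivePtr to the concrete ops class and calling its typed getter, the infer function reads the attribute straight off the primitive with GetValue<T>(primitive->GetAttr(...)). The sketch below models that pattern with simplified stand-ins; Primitive, ArgMaxPrim and kAxis here are not the real MindSpore types (and the real GetAttr returns a ValuePtr that GetValue<T> unwraps), so treat it as illustrative only, not part of the patch:

// Minimal, self-contained model of the "avoid cast when doing infer" pattern (assumed names).
#include <cassert>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct Primitive {
  explicit Primitive(std::string name) : name_(std::move(name)) {}
  virtual ~Primitive() = default;
  const std::string &name() const { return name_; }
  void AddAttr(const std::string &key, int64_t value) { attrs_[key] = value; }
  int64_t GetAttr(const std::string &key) const { return attrs_.at(key); }

 private:
  std::string name_;
  std::map<std::string, int64_t> attrs_;
};
using PrimitivePtr = std::shared_ptr<Primitive>;

constexpr const char kAxis[] = "axis";

// Concrete primitive class with a typed getter, as the ops classes have.
struct ArgMaxPrim : Primitive {
  ArgMaxPrim() : Primitive("ArgMax") {}
  int64_t get_axis() const { return GetAttr(kAxis); }
};

// Before: the infer function must know, and cast to, the concrete class.
int64_t InferAxisWithCast(const PrimitivePtr &primitive) {
  auto arg_max = std::dynamic_pointer_cast<ArgMaxPrim>(primitive);
  assert(arg_max != nullptr);  // fails whenever the primitive is not that exact class
  return arg_max->get_axis();
}

// After: the attribute is read off the generic primitive, no cast needed.
int64_t InferAxisFromAttr(const PrimitivePtr &primitive) { return primitive->GetAttr(kAxis); }

int main() {
  PrimitivePtr prim = std::make_shared<ArgMaxPrim>();
  prim->AddAttr(kAxis, -1);
  std::cout << InferAxisWithCast(prim) << " " << InferAxisFromAttr(prim) << std::endl;
  return 0;
}

Written this way, inference only needs the attribute to be present on whatever Primitive it receives, which is presumably why the patch drops the casts: the same infer code keeps working even when the primitive was not constructed as the concrete ops subclass.
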
diff --git a/mindspore/core/ops/abs.cc b/mindspore/core/ops/abs.cc index ed0574c38e..984175d324 100644 --- a/mindspore/core/ops/abs.cc +++ b/mindspore/core/ops/abs.cc @@ -30,9 +30,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto abs_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(abs_prim); - auto prim_name = abs_prim->name(); + auto prim_name = primitive->name(); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/ops/adam.cc b/mindspore/core/ops/adam.cc index 6754b84598..05dea617bf 100644 --- a/mindspore/core/ops/adam.cc +++ b/mindspore/core/ops/adam.cc @@ -23,9 +23,7 @@ namespace ops { namespace { abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto Adam_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(Adam_prim); - auto prim_name = Adam_prim->name(); + auto prim_name = primitive->name(); // infer shape auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name); diff --git a/mindspore/core/ops/add.cc b/mindspore/core/ops/add.cc index c041153f7f..e1de85f290 100644 --- a/mindspore/core/ops/add.cc +++ b/mindspore/core/ops/add.cc @@ -27,9 +27,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto add_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(add_prim); - auto prim_name = add_prim->name(); + auto prim_name = primitive->name(); return BroadCastInferShape(prim_name, input_args); } diff --git a/mindspore/core/ops/arg_max.cc b/mindspore/core/ops/arg_max.cc index c2f9537303..06e96df3f8 100644 --- a/mindspore/core/ops/arg_max.cc +++ b/mindspore/core/ops/arg_max.cc @@ -22,9 +22,7 @@ namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(prim); - auto axis = prim->get_axis(); + auto axis = GetValue(primitive->GetAttr(kAxis)); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto x_rank = SizeToLong(x_shape.size()); CheckAndConvertUtils::CheckInRange("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); diff --git a/mindspore/core/ops/arg_min.cc b/mindspore/core/ops/arg_min.cc index 230b61855d..def1e63c10 100644 --- a/mindspore/core/ops/arg_min.cc +++ b/mindspore/core/ops/arg_min.cc @@ -27,10 +27,7 @@ void ArgMin::Init(const int64_t axis, const TypeId output_type) { void ArgMin::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } void ArgMin::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); } -int64_t ArgMin::get_axis() const { - auto value_ptr = GetAttr(kAxis); - return GetValue(value_ptr); -} +int64_t ArgMin::get_axis() const { return GetValue(GetAttr(kAxis)); } TypeId ArgMin::get_output_type() const { auto type_ptr = GetAttr(kOutputType)->cast()->element(); @@ -40,13 +37,11 @@ TypeId ArgMin::get_output_type() const { AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto argmin_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(argmin_prim); - auto 
prim_name = argmin_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name); // Infer shape - auto axis = argmin_prim->get_axis(); + auto axis = GetValue(primitive->GetAttr(kAxis)); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto x_rank = SizeToLong(x_shape.size()); CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); diff --git a/mindspore/core/ops/asin.cc b/mindspore/core/ops/asin.cc index 32c16249bc..ab75a30450 100644 --- a/mindspore/core/ops/asin.cc +++ b/mindspore/core/ops/asin.cc @@ -25,9 +25,7 @@ namespace ops { AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto asin_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(asin_prim); - auto prim_name = asin_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name); // Infer Shape diff --git a/mindspore/core/ops/assert.cc b/mindspore/core/ops/assert.cc index 9831df4217..09a0a5028f 100644 --- a/mindspore/core/ops/assert.cc +++ b/mindspore/core/ops/assert.cc @@ -37,9 +37,7 @@ int64_t Assert::get_summarize() const { AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto Assert_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(Assert_prim); - auto op_name = Assert_prim->name(); + auto op_name = primitive->name(); TypePtr condition; if (!(input_args[0]->BuildType()->type_id() == kObjectTypeTensorType)) { auto condition_value = GetValue>(input_args[0]->BuildValue()); diff --git a/mindspore/core/ops/assign_add.cc b/mindspore/core/ops/assign_add.cc index 09aca18ffe..8d87c74eea 100644 --- a/mindspore/core/ops/assign_add.cc +++ b/mindspore/core/ops/assign_add.cc @@ -25,9 +25,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto assignadd_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(assignadd_prim); - auto prim_name = assignadd_prim->name(); + auto prim_name = primitive->name(); auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name); return std::make_shared(value_shape); diff --git a/mindspore/core/ops/atan.cc b/mindspore/core/ops/atan.cc index 21a014cdf4..1335c6476b 100644 --- a/mindspore/core/ops/atan.cc +++ b/mindspore/core/ops/atan.cc @@ -23,9 +23,7 @@ namespace ops { AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto atan_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(atan_prim); - auto prim_name = atan_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); // Infer Shape diff --git a/mindspore/core/ops/audio_spectrogram.cc b/mindspore/core/ops/audio_spectrogram.cc index b0f5d8c6ff..b8e023b57f 100644 --- a/mindspore/core/ops/audio_spectrogram.cc +++ b/mindspore/core/ops/audio_spectrogram.cc @@ -30,25 +30,25 @@ namespace { abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - 
auto audio_spectrogram_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(audio_spectrogram_prim); - auto prim_name = audio_spectrogram_prim->name(); + auto prim_name = primitive->name(); auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); if (input_shape.size() != 2) { MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; } - if (audio_spectrogram_prim->get_window_size() < 2) { - MS_LOG(ERROR) << "window size is too short, now is " << audio_spectrogram_prim->get_window_size(); + auto window_size = GetValue(primitive->GetAttr(kWindowSize)); + if (window_size < 2) { + MS_LOG(ERROR) << "window size is too short, now is " << window_size; } - if (audio_spectrogram_prim->get_stride() < 1) { - MS_LOG(ERROR) << "stride must be positive, now is " << audio_spectrogram_prim->get_stride(); + auto stride_size = GetValue(primitive->GetAttr(kStride)); + if (stride_size < 1) { + MS_LOG(ERROR) << "stride must be positive, now is " << stride_size; } std::vector infer_shape; infer_shape.push_back(input_shape[1]); - int64_t sample_sub_window = input_shape[0] - audio_spectrogram_prim->get_window_size(); - infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / audio_spectrogram_prim->get_stride()); - int64_t fft_length = audio_spectrogram_prim->GetFftLength(audio_spectrogram_prim->get_window_size()); + int64_t sample_sub_window = input_shape[0] - window_size; + infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / stride_size); + int64_t fft_length = GetFftLength(window_size); infer_shape.push_back(fft_length / 2 + 1); MS_LOG(ERROR) << infer_shape; return std::make_shared(infer_shape); @@ -81,7 +81,7 @@ int64_t AudioSpectrogram::get_stride() const { return GetValue(value_ptr); } -int64_t AudioSpectrogram::Log2Ceil(int64_t length) { +int64_t Log2Ceil(int64_t length) { if (length == 0) { return -1; } @@ -97,7 +97,7 @@ int64_t AudioSpectrogram::Log2Ceil(int64_t length) { return length == (length & ~(unsigned int)(length - 1)) ? 
floor : floor + 1; } -int64_t AudioSpectrogram::GetFftLength(int64_t length) { +int64_t GetFftLength(int64_t length) { int64_t shift = Log2Ceil(length); return 1 << (unsigned int)shift; } diff --git a/mindspore/core/ops/audio_spectrogram.h b/mindspore/core/ops/audio_spectrogram.h index 7d4a11d71a..54173ccc3e 100644 --- a/mindspore/core/ops/audio_spectrogram.h +++ b/mindspore/core/ops/audio_spectrogram.h @@ -27,6 +27,8 @@ namespace mindspore { namespace ops { constexpr auto kNameAudioSpectrogram = "AudioSpectrogram"; +int64_t Log2Ceil(int64_t length); +int64_t GetFftLength(int64_t length); class AudioSpectrogram : public PrimitiveC { public: AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {} @@ -39,8 +41,6 @@ class AudioSpectrogram : public PrimitiveC { int64_t get_window_size() const; int64_t get_stride() const; bool get_mag_square() const; - int64_t Log2Ceil(int64_t length); - int64_t GetFftLength(int64_t length); }; AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); diff --git a/mindspore/core/ops/avg_pool.cc b/mindspore/core/ops/avg_pool.cc index 11462a7273..9c6abfad5f 100644 --- a/mindspore/core/ops/avg_pool.cc +++ b/mindspore/core/ops/avg_pool.cc @@ -31,37 +31,25 @@ void AvgPool::set_pad_mode(const PadMode &pad_mode) { this->AddAttr(kPadMode, MakeValue(swi)); } -PadMode AvgPool::get_pad_mode() const { - auto value_ptr = GetAttr(kPadMode); - return PadMode(GetValue(value_ptr)); -} +PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue(GetAttr(kPadMode))); } void AvgPool::set_kernel_size(const std::vector &kernel_size) { this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); } -std::vector AvgPool::get_kernel_size() const { - auto value_ptr = GetAttr(kKernelSize); - return GetValue>(value_ptr); -} +std::vector AvgPool::get_kernel_size() const { return GetValue>(GetAttr(kKernelSize)); } void AvgPool::set_strides(const std::vector &strides) { this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); } -std::vector AvgPool::get_strides() const { - auto value_ptr = GetAttr(kStrides); - return GetValue>(value_ptr); -} +std::vector AvgPool::get_strides() const { return GetValue>(GetAttr(kStrides)); } void AvgPool::set_format(const Format &format) { int64_t f = format; this->AddAttr(kFormat, MakeValue(f)); } -Format AvgPool::get_format() const { - auto value_ptr = GetAttr(kFormat); - return Format(GetValue(value_ptr)); -} +Format AvgPool::get_format() const { return Format(GetValue(GetAttr(kFormat))); } void AvgPool::set_pad(const std::vector &pad) { this->AddAttr(kPad, MakeValue(pad)); } @@ -93,22 +81,20 @@ void AvgPool::Init(const std::vector &kernel_size, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto pool_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(pool_prim); - auto op_name = pool_prim->name(); + auto op_name = primitive->name(); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name); - if (pool_prim->get_format() == NHWC) { + auto format = Format(GetValue(primitive->GetAttr(kFormat))); + if (format == NHWC) { in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; } CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); - auto kernel_size = pool_prim->get_kernel_size(); - auto pad_mode = pool_prim->get_pad_mode(); + auto kernel_size = 
GetValue>(primitive->GetAttr(kKernelSize)); + auto pad_mode = PadMode(GetValue(primitive->GetAttr(kPadMode))); auto batch = in_shape[0]; auto channel = in_shape[1]; auto in_h = in_shape[2]; auto in_w = in_shape[3]; - - auto strides = pool_prim->get_strides(); + auto strides = GetValue>(primitive->GetAttr(kStrides)); auto kernel_h = kernel_size[2]; auto kernel_w = kernel_size[3]; auto stride_h = strides[2]; @@ -123,7 +109,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector out_shape = {batch, channel, out_h, out_w}; - if (pool_prim->get_format() == NHWC) { + if (format == NHWC) { out_shape = {batch, out_h, out_w, channel}; } if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { diff --git a/mindspore/core/ops/batch_norm.cc b/mindspore/core/ops/batch_norm.cc index 75cba64e47..57c3d8e284 100644 --- a/mindspore/core/ops/batch_norm.cc +++ b/mindspore/core/ops/batch_norm.cc @@ -72,13 +72,12 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit const std::vector &input_args) { // Infer shape MS_EXCEPTION_IF_NULL(primitive); - auto batch_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(batch_prim); - auto prim_name = batch_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name); - if (batch_prim->get_format() == NHWC) { + auto format = Format(GetValue(primitive->GetAttr(kFormat))); + if (format == NHWC) { input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; } auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name); @@ -87,7 +86,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name); std::vector input_shape_norm; - if (batch_prim->get_format() == NCHW) { + if (format == NCHW) { input_shape_norm = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); } else { @@ -100,7 +99,8 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError); CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name, TypeError); - if (!batch_prim->get_is_training()) { + + if (!GetValue(primitive->GetAttr(kIsTraining))) { CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name); CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError); CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError); @@ -126,7 +126,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit auto output1 = std::make_shared(scale_type, scale); auto output2 = std::make_shared(bias_type, scale); auto output3 = std::make_shared(input_x_type, scale); - if (batch_prim->get_format() == NHWC) { + if (format == NHWC) { output2 = std::make_shared(scale_type, scale); output3 = std::make_shared(bias_type, scale); output1 = std::make_shared(input_x_type, scale); diff --git a/mindspore/core/ops/batch_norm_fold.cc b/mindspore/core/ops/batch_norm_fold.cc index 097e85f947..359e0fd279 100644 --- 
a/mindspore/core/ops/batch_norm_fold.cc +++ b/mindspore/core/ops/batch_norm_fold.cc @@ -67,9 +67,7 @@ int64_t BatchNormFold::get_freeze_bn() const { AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto BatchNormFold_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(BatchNormFold_prim); - auto op_name = BatchNormFold_prim->name(); + auto op_name = primitive->name(); auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name); auto variance_shape = CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name); diff --git a/mindspore/core/ops/batch_to_space.cc b/mindspore/core/ops/batch_to_space.cc index 2a0a176b94..800928bb82 100644 --- a/mindspore/core/ops/batch_to_space.cc +++ b/mindspore/core/ops/batch_to_space.cc @@ -47,9 +47,7 @@ std::vector> BatchToSpace::get_crops() const { AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); @@ -59,8 +57,8 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); - auto block_size = prim->get_block_size(); - auto crops = prim->get_crops(); + auto block_size = GetValue>(primitive->GetAttr(kBlockSize)); + auto crops = GetValue>>(primitive->GetAttr(kCrops)); auto out_shape = x_shape; for (size_t i = 0; i < 2; ++i) { auto x_block_prod = out_shape[i + 2] * block_size[i]; diff --git a/mindspore/core/ops/batch_to_space_nd.cc b/mindspore/core/ops/batch_to_space_nd.cc index 28bffa73fb..06a7000ef4 100644 --- a/mindspore/core/ops/batch_to_space_nd.cc +++ b/mindspore/core/ops/batch_to_space_nd.cc @@ -28,16 +28,14 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto batch_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(batch_prim); - auto prim_name = batch_prim->name(); + auto prim_name = primitive->name(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); auto out_shape = x_shape; int64_t block_shape_prod = 1; int64_t offset = 2; - auto block_shape = batch_prim->get_block_shape(); - auto crops = batch_prim->get_crops(); + auto block_shape = GetValue>(primitive->GetAttr(kBlockShape)); + auto crops = GetValue>>(primitive->GetAttr(kCrops)); int64_t size = block_shape.size(); for (int64_t i = 0; i < size; i++) { block_shape_prod = block_shape_prod * block_shape[i]; diff --git a/mindspore/core/ops/binary_cross_entropy.cc b/mindspore/core/ops/binary_cross_entropy.cc index 52010fd663..4eef0762a5 100644 --- a/mindspore/core/ops/binary_cross_entropy.cc +++ b/mindspore/core/ops/binary_cross_entropy.cc @@ -32,9 +32,7 @@ namespace { abstract::ShapePtr BinaryCrossEntroyInferShape(const 
PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto binary_cross_entropy_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(binary_cross_entropy_prim); - auto prim_name = binary_cross_entropy_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); @@ -45,8 +43,8 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive, if (weight_shape.size() < 1) { CheckAndConvertUtils::Check("x shape", y_shape, kEqual, "weight shape", weight_shape, prim_name); } - if (binary_cross_entropy_prim->get_reduction() != REDUCTION_SUM && - binary_cross_entropy_prim->get_reduction() != MEAN) { + auto reduction = Reduction(GetValue(primitive->GetAttr(kReduction))); + if (reduction != REDUCTION_SUM && reduction != MEAN) { infer_shape = {x_shape.begin(), infer_shape.end()}; } return std::make_shared(infer_shape); diff --git a/mindspore/core/ops/broadcast.cc b/mindspore/core/ops/broadcast.cc index 4cf62b610e..67d6539015 100644 --- a/mindspore/core/ops/broadcast.cc +++ b/mindspore/core/ops/broadcast.cc @@ -45,9 +45,7 @@ std::string Broadcast::get_group() const { AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto broadcast_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(broadcast_prim); - auto prim_name = broadcast_prim->name(); + auto prim_name = primitive->name(); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/ops/concat.cc b/mindspore/core/ops/concat.cc index 6c9c65d58a..e74d8fdfcf 100644 --- a/mindspore/core/ops/concat.cc +++ b/mindspore/core/ops/concat.cc @@ -32,9 +32,7 @@ void Concat::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis) AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); @@ -48,7 +46,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name); auto element0_rank = SizeToLong(element0_shape.size()); - auto axis = prim->get_axis(); + auto axis = GetValue(primitive->GetAttr(kAxis)); CheckAndConvertUtils::CheckInRange("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank}, prim_name); axis = axis < 0 ? 
axis + element0_rank : axis; diff --git a/mindspore/core/ops/constant_of_shape.cc b/mindspore/core/ops/constant_of_shape.cc index dfea5b8c1f..fb3d711fa7 100644 --- a/mindspore/core/ops/constant_of_shape.cc +++ b/mindspore/core/ops/constant_of_shape.cc @@ -31,9 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto constant_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(constant_prim); - auto data_type = TypeId(constant_prim->get_data_type()); + auto data_type = TypeId(GetValue(primitive->GetAttr(kDataType))); return TypeIdToType(data_type); } } // namespace diff --git a/mindspore/core/ops/conv2d_transpose.cc b/mindspore/core/ops/conv2d_transpose.cc index 78ce9f82a6..8463fcc0d0 100644 --- a/mindspore/core/ops/conv2d_transpose.cc +++ b/mindspore/core/ops/conv2d_transpose.cc @@ -28,9 +28,7 @@ namespace { abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto conv2d_transpose_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(conv2d_transpose_prim); - auto prim_name = conv2d_transpose_prim->name(); + auto prim_name = primitive->name(); auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name); return std::make_shared(input_shape); } diff --git a/mindspore/core/ops/cos.cc b/mindspore/core/ops/cos.cc index 77cbef9ce0..0d77b214fc 100644 --- a/mindspore/core/ops/cos.cc +++ b/mindspore/core/ops/cos.cc @@ -24,9 +24,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto cos_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(cos_prim); - auto prim_name = cos_prim->name(); + auto prim_name = primitive->name(); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/ops/crop.cc b/mindspore/core/ops/crop.cc index 92ea29bc77..f96a2705c7 100644 --- a/mindspore/core/ops/crop.cc +++ b/mindspore/core/ops/crop.cc @@ -43,9 +43,7 @@ std::vector Crop::get_offsets() const { AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto crop_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(crop_prim); - auto prim_name = crop_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); diff --git a/mindspore/core/ops/custom_extract_features.cc b/mindspore/core/ops/custom_extract_features.cc index b7d7d18ac8..007b4ca6c8 100644 --- a/mindspore/core/ops/custom_extract_features.cc +++ b/mindspore/core/ops/custom_extract_features.cc @@ -24,9 +24,7 @@ namespace ops { AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto extract_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(extract_prim); - auto prim_name = extract_prim->name(); + auto prim_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[0]); // auto input = input_args[0]; diff --git a/mindspore/core/ops/custom_normalize.cc b/mindspore/core/ops/custom_normalize.cc index 0fa67598b9..b97aabcacd 100644 --- a/mindspore/core/ops/custom_normalize.cc +++ b/mindspore/core/ops/custom_normalize.cc @@ -24,13 
+24,8 @@ namespace { abstract::ShapePtr CustomNormalizeInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto custom_normalize_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(custom_normalize_prim); - // auto prim_name = custom_normalize_prim->name(); MS_EXCEPTION_IF_NULL(input_args[0]); MS_EXCEPTION_IF_NULL(input_args[0]->BuildShape()); - // auto input_shape = - // CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); if (input_args[0]->BuildValue()->cast()->data_c() == nullptr) { MS_LOG(ERROR) << "Do infer shape in runtime."; } diff --git a/mindspore/core/ops/custom_predict.cc b/mindspore/core/ops/custom_predict.cc index e28c3af1ae..9d3f86d9a1 100644 --- a/mindspore/core/ops/custom_predict.cc +++ b/mindspore/core/ops/custom_predict.cc @@ -45,13 +45,11 @@ float CustomPredict::get_weight_threshold() const { AbstractBasePtr CustomPredictInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto CustomPredict_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(CustomPredict_prim); for (const auto &input : input_args) { MS_EXCEPTION_IF_NULL(input); } std::vector shape; - shape.push_back(CustomPredict_prim->get_output_num()); + shape.push_back(GetValue(primitive->GetAttr(kOutputNum))); auto output0 = std::make_shared(kInt32, shape); auto output1 = std::make_shared(kFloat32, shape); diff --git a/mindspore/core/ops/depth_to_space.cc b/mindspore/core/ops/depth_to_space.cc index 93a0233cd2..229df9c394 100644 --- a/mindspore/core/ops/depth_to_space.cc +++ b/mindspore/core/ops/depth_to_space.cc @@ -30,19 +30,13 @@ void DepthToSpace::set_block_size(const int64_t block_size) { this->AddAttr(kBlockSize, MakeValue(block_size)); } -int64_t DepthToSpace::get_block_size() const { - auto value_ptr = GetAttr(kBlockSize); - return GetValue(value_ptr); -} +int64_t DepthToSpace::get_block_size() const { return GetValue(GetAttr(kBlockSize)); } void DepthToSpace::set_format(const Format &format) { int64_t f = format; this->AddAttr(kFormat, MakeValue(f)); } -Format DepthToSpace::get_format() const { - auto value_ptr = GetAttr(kFormat); - return Format(GetValue(value_ptr)); -} +Format DepthToSpace::get_format() const { return Format(GetValue(GetAttr(kFormat))); } void DepthToSpace::Init(const int64_t block_size, const Format &format) { this->set_block_size(block_size); @@ -52,9 +46,7 @@ void DepthToSpace::Init(const int64_t block_size, const Format &format) { AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); @@ -63,18 +55,19 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri MS_EXCEPTION_IF_NULL(input_x); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); - if (prim->get_format() == NHWC) { + auto format = Format(GetValue(primitive->GetAttr(kFormat))); + if (format == NHWC) { x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; } CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); - int64_t 
block_size = prim->get_block_size(); + int64_t block_size = GetValue(primitive->GetAttr(kBlockSize)); CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)", x_shape[1] % (block_size * block_size), kEqual, 0, prim_name); auto out_shape = x_shape; out_shape[1] /= block_size * block_size; out_shape[2] *= block_size; out_shape[3] *= block_size; - if (prim->get_format() == NHWC) { + if (format == NHWC) { out_shape = {out_shape[0], out_shape[2], out_shape[3], out_shape[1]}; } auto ret = input_x->Broaden(); diff --git a/mindspore/core/ops/depthwise_conv2d.cc b/mindspore/core/ops/depthwise_conv2d.cc index 3b4b5e5961..cf600fbb99 100644 --- a/mindspore/core/ops/depthwise_conv2d.cc +++ b/mindspore/core/ops/depthwise_conv2d.cc @@ -65,25 +65,14 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector DepthWiseConv2D::get_kernel_size() const { - auto value_ptr = GetAttr(kKernelSize); - return GetValue>(value_ptr); -} -std::vector DepthWiseConv2D::get_stride() const { - auto value_ptr = GetAttr(kStride); - return GetValue>(value_ptr); + return GetValue>(GetAttr(kKernelSize)); } +std::vector DepthWiseConv2D::get_stride() const { return GetValue>(GetAttr(kStride)); } std::vector DepthWiseConv2D::get_dilation() const { - auto value_ptr = GetAttr(kDilation); - return GetValue>(value_ptr); -} -PadMode DepthWiseConv2D::get_pad_mode() const { - auto value_ptr = this->GetAttr(kPadMode); - return PadMode(GetValue(value_ptr)); -} -std::vector DepthWiseConv2D::get_pad() const { - auto value_ptr = this->GetAttr(kPad); - return GetValue>(value_ptr); + return GetValue>(GetAttr(kDilation)); } +PadMode DepthWiseConv2D::get_pad_mode() const { return PadMode(GetValue(GetAttr(kPadMode))); } +std::vector DepthWiseConv2D::get_pad() const { return GetValue>(GetAttr(kPad)); } std::vector DepthWiseConv2D::get_pads() const { auto value_ptr = this->GetAttr(kPads); @@ -99,10 +88,7 @@ int64_t DepthWiseConv2D::get_group() const { auto value_ptr = this->GetAttr(kGroup); return GetValue(value_ptr); } -int64_t DepthWiseConv2D::get_out_channel() const { - auto value_ptr = this->GetAttr(kOutChannel); - return GetValue(value_ptr); -} +int64_t DepthWiseConv2D::get_out_channel() const { return GetValue(GetAttr(kOutChannel)); } void DepthWiseConv2D::set_kernel_size(const std::vector &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); @@ -126,33 +112,29 @@ void DepthWiseConv2D::set_format(const Format &format) { this->AddAttr(kFormat, MakeValue(f)); } -Format DepthWiseConv2D::get_format() const { - auto value_ptr = GetAttr(kFormat); - return Format(GetValue(value_ptr)); -} +Format DepthWiseConv2D::get_format() const { return Format(GetValue(GetAttr(kFormat))); } abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto conv_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(conv_prim); - auto prim_name = conv_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInRange("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); - if (conv_prim->get_format() == NHWC) { + auto format = Format(GetValue(primitive->GetAttr(kFormat))); + if (format == NHWC) { x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; w_shape = {w_shape[0], 
w_shape[3], w_shape[1], w_shape[2]}; } CheckAndConvertUtils::CheckInteger("weight_rank", w_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("x_rank", x_shape.size(), kEqual, 4, prim_name); - CheckAndConvertUtils::Check("x_shape[1]", x_shape[1], kEqual, "w_shape[1]", w_shape[1], conv_prim->name()); - auto out_channel = conv_prim->get_out_channel(); + CheckAndConvertUtils::Check("x_shape[1]", x_shape[1], kEqual, "w_shape[1]", w_shape[1], prim_name); + auto out_channel = GetValue(primitive->GetAttr(kOutChannel)); std::vector temp_w; std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); - CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w, - conv_prim->name()); + CheckAndConvertUtils::Check("kernel_size", GetValue>(primitive->GetAttr(kKernelSize)), kEqual, + "w_shape[2:4]", temp_w, prim_name); auto kernel_size_n = w_shape[0]; if (kernel_size_n != 1) { @@ -160,8 +142,8 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, } auto kernel_size_h = w_shape[2]; auto kernel_size_w = w_shape[3]; - auto stride = conv_prim->get_stride(); - auto dilation = conv_prim->get_dilation(); + auto stride = GetValue>(primitive->GetAttr(kStride)); + auto dilation = GetValue>(primitive->GetAttr(kDilation)); auto stride_h = stride[2]; auto stride_w = stride[3]; auto dilation_h = dilation[2]; @@ -169,7 +151,7 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, int64_t h_out = -1; int64_t w_out = -1; std::vector pad_list(4, 0); - auto pad_mode = conv_prim->get_pad_mode(); + auto pad_mode = PadMode(GetValue(primitive->GetAttr(kPadMode))); if (pad_mode == VALID) { h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); @@ -187,20 +169,21 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive, pad_list.emplace_back(pad_left); pad_list.emplace_back(pad_needed_h - pad_left); } else if (pad_mode == PAD) { - std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list)); - auto pad_top = conv_prim->get_pad()[0]; - auto pad_bottom = conv_prim->get_pad()[1]; - auto pad_right = conv_prim->get_pad()[2]; - auto pad_left = conv_prim->get_pad()[3]; + auto pads = GetValue>(primitive->GetAttr(kPad)); + std::copy(pads.begin(), pads.end(), std::back_inserter(pad_list)); + auto pad_top = pads[0]; + auto pad_bottom = pads[1]; + auto pad_right = pads[2]; + auto pad_left = pads[3]; h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; h_out = floor(h_out); w_out = floor(w_out); } - conv_prim->set_pads(pad_list); + primitive->AddAttr(kPads, MakeValue(pad_list)); std::vector out_shape = {x_shape[0], out_channel * x_shape[1], h_out, w_out}; - if (conv_prim->get_format() == NHWC) { + if (format == NHWC) { out_shape = {x_shape[0], h_out, w_out, out_channel * x_shape[1]}; } return std::make_shared(out_shape); diff --git a/mindspore/core/ops/detection_post_process.cc b/mindspore/core/ops/detection_post_process.cc index 19f54ca07b..37cb666987 100644 --- a/mindspore/core/ops/detection_post_process.cc +++ b/mindspore/core/ops/detection_post_process.cc @@ -68,10 +68,7 @@ float DetectionPostProcess::get_nms_score_threshold() const { void 
DetectionPostProcess::set_max_detections(const int64_t MaxDetections) { this->AddAttr(kMaxDetections, MakeValue(MaxDetections)); } -int64_t DetectionPostProcess::get_max_detections() const { - auto value_ptr = this->GetAttr(kMaxDetections); - return GetValue(value_ptr); -} +int64_t DetectionPostProcess::get_max_detections() const { return GetValue(GetAttr(kMaxDetections)); } void DetectionPostProcess::set_detections_per_class(const int64_t DetectionsPerClass) { this->AddAttr(kDetectionsPerClass, MakeValue(DetectionsPerClass)); @@ -85,17 +82,13 @@ void DetectionPostProcess::set_max_classes_per_detection(const int64_t MaxClasse this->AddAttr(kMaxClassesPerDetection, MakeValue(MaxClassesPerDetection)); } int64_t DetectionPostProcess::get_max_classes_per_detection() const { - auto value_ptr = this->GetAttr(kMaxClassesPerDetection); - return GetValue(value_ptr); + return GetValue(GetAttr(kMaxClassesPerDetection)); } void DetectionPostProcess::set_num_classes(const int64_t NumClasses) { this->AddAttr(kNumClasses, MakeValue(NumClasses)); } -int64_t DetectionPostProcess::get_num_classes() const { - auto value_ptr = this->GetAttr(kNumClasses); - return GetValue(value_ptr); -} +int64_t DetectionPostProcess::get_num_classes() const { return GetValue(GetAttr(kNumClasses)); } void DetectionPostProcess::set_use_regular_nms(const bool UseRegularNms) { this->AddAttr(kUseRegularNms, MakeValue(UseRegularNms)); } @@ -115,16 +108,11 @@ void DetectionPostProcess::set_format(const Format &format) { int64_t f = format; this->AddAttr(kFormat, MakeValue(f)); } -Format DetectionPostProcess::get_format() const { - auto value_ptr = this->GetAttr(kFormat); - return Format(GetValue(value_ptr)); -} +Format DetectionPostProcess::get_format() const { return Format(GetValue(GetAttr(kFormat))); } AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto detection_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(detection_prim); - auto prim_name = detection_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("detection_post_process_infer", input_args.size(), kEqual, 3, prim_name); MS_EXCEPTION_IF_NULL(input_args[0]); MS_EXCEPTION_IF_NULL(input_args[1]); @@ -135,12 +123,13 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShape("boxes_shape", boxes->BuildShape(), prim_name); auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShape("scores_shape", scores->BuildShape(), prim_name); auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShape("anchors_shape", anchors->BuildShape(), prim_name); - if (detection_prim->get_format() == NHWC) { + auto format = Format(GetValue(primitive->GetAttr(kFormat))); + if (format == NHWC) { boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]}; scores_shape = {scores_shape[0], scores_shape[3], scores_shape[1], scores_shape[2]}; anchors_shape = {anchors_shape[0], anchors_shape[3], anchors_shape[1], anchors_shape[2]}; } - auto num_classes = detection_prim->get_num_classes(); + auto num_classes = GetValue(primitive->GetAttr(kNumClasses)); CheckAndConvertUtils::CheckInRange("scores_shape[2]", scores_shape[2], kIncludeBoth, {num_classes, num_classes + 1}, prim_name); CheckAndConvertUtils::Check("boxes_shape[1]", boxes_shape[1], kEqual, "scores_shape[1]", scores_shape[1], prim_name, @@ -149,8 +138,8 @@ 
 AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
   ValueError);
   // Infer shape
-  auto max_detections = detection_prim->get_max_detections();
-  auto max_classes_per_detection = detection_prim->get_max_classes_per_detection();
+  auto max_detections = GetValue(primitive->GetAttr(kMaxDetections));
+  auto max_classes_per_detection = GetValue(primitive->GetAttr(kMaxClassesPerDetection));
   auto num_detected_boxes = max_detections * max_classes_per_detection;
   std::vector output_boxes_shape = {1, num_detected_boxes, 4};
   std::vector output_class_shape = {1, num_detected_boxes};
@@ -163,7 +152,7 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
   auto output1 = std::make_shared(output_type, output_class_shape);
   auto output2 = std::make_shared(output_type, output_num_shape);
   AbstractBasePtrList output = {output0, output1, output1, output2};
-  if (detection_prim->get_format() == NHWC) {
+  if (format == NHWC) {
     output = {output0, output1, output2, output1};
   }
   return std::make_shared(output);
diff --git a/mindspore/core/ops/div.cc b/mindspore/core/ops/div.cc
index affcb3782d..0b7ac8a664 100644
--- a/mindspore/core/ops/div.cc
+++ b/mindspore/core/ops/div.cc
@@ -28,9 +28,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto div_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(div_prim);
-  auto prim_name = div_prim->name();
+  auto prim_name = primitive->name();
   return BroadCastInferShape(prim_name, input_args);
 }
diff --git a/mindspore/core/ops/dropout.cc b/mindspore/core/ops/dropout.cc
index 13a7093616..f6671e687a 100644
--- a/mindspore/core/ops/dropout.cc
+++ b/mindspore/core/ops/dropout.cc
@@ -39,9 +39,7 @@ float Dropout::get_keep_prob() const {
 AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto dropout_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(dropout_prim);
-  auto prim_name = dropout_prim->name();
+  auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name);
   // Infer shape
diff --git a/mindspore/core/ops/elu.cc b/mindspore/core/ops/elu.cc
index 7e0c16b197..59a3c3ad87 100644
--- a/mindspore/core/ops/elu.cc
+++ b/mindspore/core/ops/elu.cc
@@ -31,9 +31,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto elu_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(elu_prim);
-  auto op_name = elu_prim->name();
+  auto op_name = primitive->name();
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
diff --git a/mindspore/core/ops/equal.cc b/mindspore/core/ops/equal.cc
index f44d0742c7..b9066430be 100644
--- a/mindspore/core/ops/equal.cc
+++ b/mindspore/core/ops/equal.cc
@@ -29,9 +29,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto equal_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(equal_prim);
-  auto op_name = equal_prim->name();
+  auto op_name = primitive->name();
   return BroadCastInferShape(op_name, input_args);
 }
diff --git a/mindspore/core/ops/expand_dims.cc b/mindspore/core/ops/expand_dims.cc
index c813af5c45..8e1999ea20 100644
--- a/mindspore/core/ops/expand_dims.cc
+++ b/mindspore/core/ops/expand_dims.cc
@@ -30,9 +30,7 @@ namespace ops {
 AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto expand_dims_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(expand_dims_prim);
-  auto prim_name = expand_dims_prim->name();
+  auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
diff --git a/mindspore/core/ops/fake_quant_with_min_max_vars.cc b/mindspore/core/ops/fake_quant_with_min_max_vars.cc
index 3933f0647d..495a185f7c 100644
--- a/mindspore/core/ops/fake_quant_with_min_max_vars.cc
+++ b/mindspore/core/ops/fake_quant_with_min_max_vars.cc
@@ -28,9 +28,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto fake_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(fake_prim);
-  auto prim_name = fake_prim->name();
+  auto prim_name = primitive->name();
   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
   auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), prim_name);
   auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), prim_name);
diff --git a/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc b/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
index 80dc976c41..da25c34092 100644
--- a/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
+++ b/mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
@@ -43,9 +43,7 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
                                                        const PrimitivePtr &primitive,
                                                        const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto FakeQuantWithMinMaxVarsPerChannel_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(FakeQuantWithMinMaxVarsPerChannel_prim);
-  auto op_name = FakeQuantWithMinMaxVarsPerChannel_prim->name();
+  auto op_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
   auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), op_name);
   auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), op_name);
diff --git a/mindspore/core/ops/fft_imag.cc b/mindspore/core/ops/fft_imag.cc
index fb2545041c..762dd45914 100644
--- a/mindspore/core/ops/fft_imag.cc
+++ b/mindspore/core/ops/fft_imag.cc
@@ -24,9 +24,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto FftImag_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(FftImag_prim);
-  auto prim_name = FftImag_prim->name();
+  auto prim_name = primitive->name();
   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name);
   in_shape.pop_back();
   return std::make_shared(in_shape);
diff --git a/mindspore/core/ops/flatten.cc b/mindspore/core/ops/flatten.cc
index 9c09f8f7c9..c18f39f8dd 100644
--- a/mindspore/core/ops/flatten.cc
+++ b/mindspore/core/ops/flatten.cc
@@ -23,9 +23,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto flatten_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(flatten_prim);
-  auto prim_name = flatten_prim->name();
+  auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name);
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
   auto prod = 1;
diff --git a/mindspore/core/ops/floor.cc b/mindspore/core/ops/floor.cc
index 56bc0c6370..52ac1e0b3d 100644
--- a/mindspore/core/ops/floor.cc
+++ b/mindspore/core/ops/floor.cc
@@ -28,9 +28,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto floor_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(floor_prim);
-  auto prim_name = floor_prim->name();
+  auto prim_name = primitive->name();
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
diff --git a/mindspore/core/ops/fusion/add_fusion.cc b/mindspore/core/ops/fusion/add_fusion.cc
index 384c2f3007..1635fd58cc 100644
--- a/mindspore/core/ops/fusion/add_fusion.cc
+++ b/mindspore/core/ops/fusion/add_fusion.cc
@@ -39,9 +39,7 @@ void AddFusion::Init(const ActivationType activation_type) { this->set_activatio
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto add_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(add_prim);
-  auto op_name = add_prim->name();
+  auto op_name = primitive->name();
   return BroadCastInferShape(op_name, input_args);
 }
diff --git a/mindspore/core/ops/fusion/avg_pool_fusion.cc b/mindspore/core/ops/fusion/avg_pool_fusion.cc
index 6013bbfcf0..2b333a2577 100644
--- a/mindspore/core/ops/fusion/avg_pool_fusion.cc
+++ b/mindspore/core/ops/fusion/avg_pool_fusion.cc
@@ -52,22 +52,21 @@ ActivationType AvgPoolFusion::get_activation_type() const {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto pool_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(pool_prim);
-  auto op_name = pool_prim->name();
+  auto op_name = primitive->name();
   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
-  if (pool_prim->get_format() == NHWC) {
+  auto format = Format(GetValue(primitive->GetAttr(kFormat)));
+  if (format == NHWC) {
     in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
   }
   CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
-  auto kernel_size = pool_prim->get_kernel_size();
-  auto pad_mode = pool_prim->get_pad_mode();
+  auto kernel_size = GetValue>(primitive->GetAttr(kKernelSize));
+  auto pad_mode = PadMode(GetValue(primitive->GetAttr(kPadMode)));
   auto batch = in_shape[0];
   auto channel = in_shape[1];
   auto in_h = in_shape[2];
   auto in_w = in_shape[3];
-  auto strides = pool_prim->get_strides();
+  auto strides = GetValue>(primitive->GetAttr(kStrides));
   auto kernel_h = kernel_size[2];
   auto kernel_w = kernel_size[3];
   auto stride_h = strides[2];
@@ -82,7 +81,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector
   out_shape = {batch, channel, out_h, out_w};
-  if (pool_prim->get_format() == NHWC) {
+  if (format == NHWC) {
     out_shape = {batch, out_h, out_w, channel};
   }
   if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
diff --git a/mindspore/core/ops/fusion/full_connection.cc b/mindspore/core/ops/fusion/full_connection.cc
index 8cbe0420c7..602d26e7af 100644
--- a/mindspore/core/ops/fusion/full_connection.cc
+++ b/mindspore/core/ops/fusion/full_connection.cc
@@ -21,22 +21,13 @@ namespace mindspore {
 namespace ops {
 void FullConnection::set_has_bias(const bool has_bias) { this->AddAttr(kHasBias, MakeValue(has_bias)); }
-bool FullConnection::get_has_bias() const {
-  auto value_ptr = GetAttr(kHasBias);
-  return GetValue(value_ptr);
-}
+bool FullConnection::get_has_bias() const { return GetValue(GetAttr(kHasBias)); }
 void FullConnection::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
-int64_t FullConnection::get_axis() const {
-  auto value_ptr = GetAttr(kAxis);
-  return GetValue(value_ptr);
-}
+int64_t FullConnection::get_axis() const { return GetValue(GetAttr(kAxis)); }
 void FullConnection::set_use_axis(const bool use_axis) { this->AddAttr(kUseAxis, MakeValue(use_axis)); }
-bool FullConnection::get_use_axis() const {
-  auto value_ptr = GetAttr(kUseAxis);
-  return GetValue(value_ptr);
-}
+bool FullConnection::get_use_axis() const { return GetValue(GetAttr(kUseAxis)); }
 void FullConnection::set_activation_type(const ActivationType &activation_type) {
   int64_t swi;
@@ -57,26 +48,26 @@ void FullConnection::Init(const bool has_bias, const int64_t axis, const bool us
 AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto full_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(full_prim);
-  auto prim_name = full_prim->name();
+  auto prim_name = primitive->name();
   MS_EXCEPTION_IF_NULL(input_args[0]);
   MS_EXCEPTION_IF_NULL(input_args[1]);
   auto input0 = input_args[0];
   auto input1 = input_args[1];
   auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input0->BuildShape(), prim_name);
   auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input1->BuildShape(), prim_name);
-  auto prim_axis = full_prim->get_axis();
-  if (full_prim->get_has_bias()) {
+  auto prim_axis = GetValue(primitive->GetAttr(kAxis));
+  auto has_bias = GetValue(primitive->GetAttr(kHasBias));
+  if (has_bias) {
     CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 3, prim_name);
   } else {
     CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 2, prim_name);
   }
-  if (full_prim->get_use_axis() && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) {
+  auto use_axis = GetValue(primitive->GetAttr(kUseAxis));
+  if (use_axis && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) {
     MS_EXCEPTION(ValueError) << "Full Connection axis invalid";
   }
   int64_t new_k = 1;
-  if (full_prim->get_use_axis()) {
+  if (use_axis) {
     for (size_t t = prim_axis; t < input0_shape.size(); t++) {
       new_k *= input0_shape[t];
     }
@@ -86,7 +77,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
   } else {
     new_k = input1_shape[1];
   }
-  if (full_prim->get_has_bias()) {
+  if (has_bias) {
     auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), prim_name);
     if (input2_shape[0] != input1_shape[0]) {
@@ -94,7 +85,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
     }
   }
   std::vector out_shape = {(int64_t)input0_shape.size()};
-  if (full_prim->get_use_axis()) {
+  if (use_axis) {
     out_shape.resize(prim_axis + 1);
     out_shape[prim_axis] = input1_shape[0];
   } else {
diff --git a/mindspore/core/ops/fusion/max_pool_fusion.cc b/mindspore/core/ops/fusion/max_pool_fusion.cc
index b8545e021e..c0bdc079a7 100644
--- a/mindspore/core/ops/fusion/max_pool_fusion.cc
+++ b/mindspore/core/ops/fusion/max_pool_fusion.cc
@@ -52,22 +52,21 @@ ActivationType MaxPoolFusion::get_activation_type() const {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto pool_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(pool_prim);
-  auto op_name = pool_prim->name();
+  auto op_name = primitive->name();
   auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
-  if (pool_prim->get_format() == NHWC) {
+  auto format = Format(GetValue(primitive->GetAttr(kFormat)));
+  if (format == NHWC) {
     in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
   }
   CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
-  auto kernel_size = pool_prim->get_kernel_size();
-  auto pad_mode = pool_prim->get_pad_mode();
+  auto kernel_size = GetValue>(primitive->GetAttr(kKernelSize));
+  auto pad_mode = PadMode(GetValue(primitive->GetAttr(kPadMode)));
   auto batch = in_shape[0];
   auto channel = in_shape[1];
   auto in_h = in_shape[2];
   auto in_w = in_shape[3];
-  auto strides = pool_prim->get_strides();
+  auto strides = GetValue>(primitive->GetAttr(kStrides));
   auto kernel_h = kernel_size[2];
   auto kernel_w = kernel_size[3];
   auto stride_h = strides[2];
@@ -82,7 +81,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector
   out_shape = {batch, channel, out_h, out_w};
-  if (pool_prim->get_format() == NHWC) {
+  if (format == NHWC) {
     out_shape = {batch, out_h, out_w, channel};
   }
   if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
diff --git a/mindspore/core/ops/fusion/pow_fusion.cc b/mindspore/core/ops/fusion/pow_fusion.cc
index 9fadab92b7..648c174296 100644
--- a/mindspore/core/ops/fusion/pow_fusion.cc
+++ b/mindspore/core/ops/fusion/pow_fusion.cc
@@ -37,9 +37,7 @@ float PowFusion::get_shift() const { return GetValue(GetAttr(kShift)); }
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto pow_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(pow_prim);
-  auto op_name = pow_prim->name();
+  auto op_name = primitive->name();
   return BroadCastInferShape(op_name, input_args);
 }
diff --git a/mindspore/core/ops/fusion/slice_fusion.cc b/mindspore/core/ops/fusion/slice_fusion.cc
index af4db80ae1..dd1af5c812 100644
--- a/mindspore/core/ops/fusion/slice_fusion.cc
+++ b/mindspore/core/ops/fusion/slice_fusion.cc
@@ -32,9 +32,7 @@ std::vector SliceFusion::get_axes() const {
 AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto SliceFusion_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(SliceFusion_prim);
-  auto op_name = SliceFusion_prim->name();
+  auto op_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
   auto x_shape_len = (int64_t)x_shape.size();
   auto begin_v = input_args[1]->BuildValue();
diff --git a/mindspore/core/ops/gather_nd.cc b/mindspore/core/ops/gather_nd.cc
index 12760e195e..53bd1fa51b 100644
--- a/mindspore/core/ops/gather_nd.cc
+++ b/mindspore/core/ops/gather_nd.cc
@@ -27,9 +27,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto gather_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(gather_prim);
-  auto prim_name = gather_prim->name();
+  auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
diff --git a/mindspore/core/ops/gelu.cc b/mindspore/core/ops/gelu.cc
index e9a429d424..f48a1f3252 100644
--- a/mindspore/core/ops/gelu.cc
+++ b/mindspore/core/ops/gelu.cc
@@ -28,9 +28,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto gelu_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(gelu_prim);
-  auto prim_name = gelu_prim->name();
+  auto prim_name = primitive->name();
   auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
   return std::make_shared(input_shape);
 }
diff --git a/mindspore/core/ops/grad/avg_pool_grad.cc b/mindspore/core/ops/grad/avg_pool_grad.cc
index 13a78a904d..b1688449f2 100644
--- a/mindspore/core/ops/grad/avg_pool_grad.cc
+++ b/mindspore/core/ops/grad/avg_pool_grad.cc
@@ -22,8 +22,6 @@ namespace ops {
 AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto AvgPoolGrad_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(AvgPoolGrad_prim);
   MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
   auto origin_input_shape = GetValue>(input_args[0]->BuildValue());
   auto tensor_type = input_args[1]->BuildType()->cast();
diff --git a/mindspore/core/ops/grad/batch_norm_grad.cc b/mindspore/core/ops/grad/batch_norm_grad.cc
index 738d2fb20f..407dca7db1 100644
--- a/mindspore/core/ops/grad/batch_norm_grad.cc
+++ b/mindspore/core/ops/grad/batch_norm_grad.cc
@@ -47,9 +47,7 @@ bool BatchNormGrad::get_is_training() const {
 AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto BatchNormGrad_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(BatchNormGrad_prim);
-  auto op_name = BatchNormGrad_prim->name();
+  auto op_name = primitive->name();
   MS_EXCEPTION_IF_NULL(input_args[1]);
   MS_EXCEPTION_IF_NULL(input_args[2]);
   MS_EXCEPTION_IF_NULL(input_args[3]);
diff --git a/mindspore/core/ops/grad/bias_add_grad.cc b/mindspore/core/ops/grad/bias_add_grad.cc
index 9ff1c387bd..0612806c9f 100644
--- a/mindspore/core/ops/grad/bias_add_grad.cc
+++ b/mindspore/core/ops/grad/bias_add_grad.cc
@@ -41,9 +41,7 @@ Format BiasAddGrad::get_format() const {
 AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto bias_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(bias_prim);
-  auto prim_name = bias_prim->name();
+  auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("bias_grad_infer", input_args.size(), kEqual, 1, prim_name);
   MS_EXCEPTION_IF_NULL(input_args[0]);
diff --git a/mindspore/core/ops/grad/binary_cross_entropy_grad.cc b/mindspore/core/ops/grad/binary_cross_entropy_grad.cc
index fd20e41e70..a3d17aa74f 100644
---
a/mindspore/core/ops/grad/binary_cross_entropy_grad.cc +++ b/mindspore/core/ops/grad/binary_cross_entropy_grad.cc @@ -26,9 +26,7 @@ namespace { abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto binary_cross_entropy_grad_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(binary_cross_entropy_grad_prim); - auto prim_name = binary_cross_entropy_grad_prim->name(); + auto prim_name = primitive->name(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name); auto weight_shape = diff --git a/mindspore/core/ops/grad/dropout_grad.cc b/mindspore/core/ops/grad/dropout_grad.cc index d73fa9ea05..03fadb0837 100644 --- a/mindspore/core/ops/grad/dropout_grad.cc +++ b/mindspore/core/ops/grad/dropout_grad.cc @@ -35,18 +35,14 @@ namespace { abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto DropoutGrad_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(DropoutGrad_prim); - auto op_name = DropoutGrad_prim->name(); + auto op_name = primitive->name(); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); return std::make_shared(in_shape); } TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); - auto DropoutGrad_prim = prim->cast(); - MS_EXCEPTION_IF_NULL(DropoutGrad_prim); - auto op_name = DropoutGrad_prim->name(); + auto op_name = prim->name(); auto mask_dtype = input_args[1]->BuildType(); auto dy_dtype = input_args[0]->BuildType(); CheckAndConvertUtils::CheckTensorTypeValid("mask", mask_dtype, {kTensorType}, op_name); diff --git a/mindspore/core/ops/grad/group_conv2d_grad_input.cc b/mindspore/core/ops/grad/group_conv2d_grad_input.cc index f4f553c380..f74924d01f 100644 --- a/mindspore/core/ops/grad/group_conv2d_grad_input.cc +++ b/mindspore/core/ops/grad/group_conv2d_grad_input.cc @@ -114,8 +114,7 @@ void GroupConv2DGradInput::set_input_shape(const std::vector &input_sha } std::vector GroupConv2DGradInput::get_input_shape() const { - auto value_ptr = GetAttr(kInputShape); - return GetValue>(value_ptr); + return GetValue>(GetAttr(kInputShape)); } void GroupConv2DGradInput::set_format(const Format &format) { @@ -147,14 +146,12 @@ bool GroupConv2DGradInput::get_has_bias() const { AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto group_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(group_prim); - auto prim_name = group_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", input_args.size(), kGreaterEqual, 2, prim_name); MS_EXCEPTION_IF_NULL(input_args[0]); // Infer shape - auto shape = group_prim->get_input_shape(); + auto shape = GetValue>(primitive->GetAttr(kInputShape)); // Infer type auto type = input_args[0]->BuildType()->cast()->element(); diff --git a/mindspore/core/ops/grad/max_pool_grad.cc b/mindspore/core/ops/grad/max_pool_grad.cc index d3575c24df..64120ab95f 100644 --- a/mindspore/core/ops/grad/max_pool_grad.cc +++ b/mindspore/core/ops/grad/max_pool_grad.cc @@ -21,9 +21,8 @@ namespace mindspore { namespace ops { AbstractBasePtr 
MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - auto MaxPoolGrad_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(MaxPoolGrad_prim); - auto op_name = MaxPoolGrad_prim->name(); + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name); auto tensor_type = input_args[0]->BuildType()->cast(); diff --git a/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc b/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc index bd559e3798..c08730a285 100644 --- a/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc +++ b/mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc @@ -30,9 +30,7 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto sigmoid_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(sigmoid_prim); - auto prim_name = sigmoid_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", input_args.size(), kEqual, 3, prim_name); diff --git a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc index 26585f34ed..59215e6041 100644 --- a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc +++ b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc @@ -36,9 +36,7 @@ float SmoothL1LossGrad::get_beta() const { AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto smooth_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(smooth_prim); - auto prim_name = smooth_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name); // Infer shape diff --git a/mindspore/core/ops/hashtable_lookup.cc b/mindspore/core/ops/hashtable_lookup.cc index d3e16efc98..1f039b5814 100644 --- a/mindspore/core/ops/hashtable_lookup.cc +++ b/mindspore/core/ops/hashtable_lookup.cc @@ -24,12 +24,10 @@ namespace ops { AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto HashtableLookup_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(HashtableLookup_prim); for (auto input : input_args) { MS_EXCEPTION_IF_NULL(input); } - auto op_name = HashtableLookup_prim->name(); + auto op_name = primitive->name(); std::vector hits_shape; auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); hits_shape.push_back(input[0]); diff --git a/mindspore/core/ops/l2_normalize.cc b/mindspore/core/ops/l2_normalize.cc index 08d4ce599d..b5e7ef0fe7 100644 --- a/mindspore/core/ops/l2_normalize.cc +++ b/mindspore/core/ops/l2_normalize.cc @@ -29,10 +29,7 @@ void L2Normalize::set_axis(const std::vector &axis) { AddAttr(kAxis, Ma void L2Normalize::set_epsilon(const float epsilon) { AddAttr(kEpsilon, MakeValue(epsilon)); } -std::vector L2Normalize::get_axis() const { - auto value_ptr = GetAttr(kAxis); - return GetValue>(value_ptr); -} +std::vector L2Normalize::get_axis() const { return GetValue>(GetAttr(kAxis)); } float L2Normalize::get_epsilon() 
const { auto value_ptr = GetAttr(kEpsilon); @@ -42,9 +39,7 @@ float L2Normalize::get_epsilon() const { AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); @@ -53,7 +48,7 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto x_rank = SizeToLong(x_shape.size()); - auto axiss = prim->get_axis(); + auto axiss = GetValue>(primitive->GetAttr(kAxis)); for (auto &axis : axiss) { CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); } diff --git a/mindspore/core/ops/less.cc b/mindspore/core/ops/less.cc index 6c5fc4e31a..e994bfd986 100644 --- a/mindspore/core/ops/less.cc +++ b/mindspore/core/ops/less.cc @@ -27,9 +27,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto less_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(less_prim); - auto op_name = less_prim->name(); + auto op_name = primitive->name(); return BroadCastInferShape(op_name, input_args); } diff --git a/mindspore/core/ops/less_equal.cc b/mindspore/core/ops/less_equal.cc index 0a8c87f664..214ada979e 100644 --- a/mindspore/core/ops/less_equal.cc +++ b/mindspore/core/ops/less_equal.cc @@ -28,9 +28,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto equal_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(equal_prim); - auto op_name = equal_prim->name(); + auto op_name = primitive->name(); return BroadCastInferShape(op_name, input_args); } diff --git a/mindspore/core/ops/logical_and.cc b/mindspore/core/ops/logical_and.cc index 4db682377e..3180aba6f0 100644 --- a/mindspore/core/ops/logical_and.cc +++ b/mindspore/core/ops/logical_and.cc @@ -28,9 +28,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto logicaland_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(logicaland_prim); - auto op_name = logicaland_prim->name(); + auto op_name = primitive->name(); return BroadCastInferShape(op_name, input_args); } diff --git a/mindspore/core/ops/logical_not.cc b/mindspore/core/ops/logical_not.cc index f6dfc6d2b9..b21c6175d9 100644 --- a/mindspore/core/ops/logical_not.cc +++ b/mindspore/core/ops/logical_not.cc @@ -24,18 +24,14 @@ namespace ops { namespace { abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto LogicalNot_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(LogicalNot_prim); - auto op_name = LogicalNot_prim->name(); + auto op_name = primitive->name(); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); return std::make_shared(in_shape); } TypePtr 
LogicalNotInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); - auto LogicalNot_prim = prim->cast(); - MS_EXCEPTION_IF_NULL(LogicalNot_prim); - auto op_name = LogicalNot_prim->name(); + auto op_name = prim->name(); auto infer_dtype = input_args[0]->BuildType(); std::set local_bool = {kBool}; return CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name); diff --git a/mindspore/core/ops/logical_or.cc b/mindspore/core/ops/logical_or.cc index e908ebb4b5..732ad342ad 100644 --- a/mindspore/core/ops/logical_or.cc +++ b/mindspore/core/ops/logical_or.cc @@ -29,9 +29,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto logicalor_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(logicalor_prim); - auto op_name = logicalor_prim->name(); + auto op_name = primitive->name(); return BroadCastInferShape(op_name, input_args); } diff --git a/mindspore/core/ops/lrn.cc b/mindspore/core/ops/lrn.cc index 7760800631..4201bc865d 100644 --- a/mindspore/core/ops/lrn.cc +++ b/mindspore/core/ops/lrn.cc @@ -77,9 +77,7 @@ void LRN::Init(const int64_t depth_radius, const float bias, const float alpha, namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto lrn_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(lrn_prim); - auto prim_name = lrn_prim->name(); + auto prim_name = primitive->name(); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name); return std::make_shared(in_shape); diff --git a/mindspore/core/ops/lsh_projection.cc b/mindspore/core/ops/lsh_projection.cc index dd9d1b8ebe..9125f36dac 100644 --- a/mindspore/core/ops/lsh_projection.cc +++ b/mindspore/core/ops/lsh_projection.cc @@ -26,20 +26,12 @@ void LshProjection::set_type(const LshProjectionType &type) { AddAttr(kType, MakeValue(swi)); } -LshProjectionType LshProjection::get_type() const { - auto value_ptr = GetAttr(kType); - return LshProjectionType(GetValue(value_ptr)); -} +LshProjectionType LshProjection::get_type() const { return LshProjectionType(GetValue(GetAttr(kType))); } AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto LshProjection_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(LshProjection_prim); - // if (input_args.size() != 2 && input_args.size() != 3) { - // MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << input_args.size() << " is given."; - // } - auto op_name = LshProjection_prim->name(); + auto op_name = primitive->name(); auto input0 = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name); auto input1 = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name); CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name); @@ -53,7 +45,7 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr } std::vector out_shape; - switch ((int64_t)LshProjection_prim->get_type()) { + switch ((int64_t)LshProjectionType(GetValue(primitive->GetAttr(kType)))) { case (int64_t)LshProjectionType::SPARSE: out_shape.push_back(input0[0]); break; diff 
--git a/mindspore/core/ops/lstm.cc b/mindspore/core/ops/lstm.cc index a42abf1be5..6b1985da9e 100644 --- a/mindspore/core/ops/lstm.cc +++ b/mindspore/core/ops/lstm.cc @@ -19,28 +19,34 @@ namespace mindspore { namespace ops { namespace { +int64_t get_good_ld(const int64_t dim, const int64_t type_size) { + int64_t ld = ((dim + (64 / type_size) - 1) / (64 / type_size)) * (64 / type_size); + if (ld * 256 == 0) { + return ld + 64 / type_size; + } + return ld; +} + AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector &input_args) { // infer shape MS_EXCEPTION_IF_NULL(primitive); - auto lstm_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(lstm_prim); - auto prim_name = lstm_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name); auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("h_shape", input_args[1]->BuildShape(), prim_name); auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("c_shape", input_args[2]->BuildShape(), prim_name); - int64_t input_x_size = lstm_prim->get_input_size(); + int64_t input_x_size = GetValue(primitive->GetAttr(kInput_size)); CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name); CheckAndConvertUtils::CheckInteger("x_shape[2]", x_input_shape[2], kEqual, input_x_size, prim_name); CheckAndConvertUtils::CheckInteger("h_shape.size()", h_input_shape.size(), kEqual, 3, prim_name); - CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, lstm_prim->name()); + CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, prim_name); - int64_t num_layers = lstm_prim->get_num_layers(); - int64_t num_directions = lstm_prim->get_num_directions(); - int64_t hidden_size = lstm_prim->get_hidden_size(); - int64_t input_size = lstm_prim->get_input_size(); + int64_t num_layers = GetValue(primitive->GetAttr(kNumLayers)); + int64_t num_directions = GetValue(primitive->GetAttr(kNumDirections)); + int64_t hidden_size = GetValue(primitive->GetAttr(kHidden_size)); + int64_t input_size = input_x_size; CheckAndConvertUtils::CheckInteger("h_shape[0]", h_input_shape[0], kEqual, num_layers * num_directions, prim_name); CheckAndConvertUtils::CheckInteger("h_shape[1]", h_input_shape[1], kEqual, x_input_shape[1], prim_name); CheckAndConvertUtils::CheckInteger("h_shape[2]", h_input_shape[2], kEqual, hidden_size, prim_name); @@ -48,8 +54,8 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector y_shape = {x_input_shape[0], x_input_shape[1], hidden_size * num_directions}; int64_t type_size = 4; - int64_t gates_ws_ld = lstm_prim->get_good_ld(hidden_size * 4, type_size); - int64_t states_ws_ld = lstm_prim->get_good_ld(std::max(hidden_size, input_size), type_size); + int64_t gates_ws_ld = get_good_ld(hidden_size * 4, type_size); + int64_t states_ws_ld = get_good_ld(std::max(hidden_size, input_size), type_size); int64_t ws_gates_size = num_layers * num_directions * x_input_shape[0] * x_input_shape[1] * gates_ws_ld * type_size; int64_t ws_states_size = (num_layers + 1) * num_directions * (x_input_shape[0] + 1) * x_input_shape[1] * states_ws_ld * type_size; @@ -99,26 +105,17 @@ void LSTM::set_input_size(const int64_t input_size) { CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name()); 
AddAttr(kInput_size, MakeValue(input_size)); } -int64_t LSTM::get_input_size() const { - auto value_ptr = this->GetAttr(kInput_size); - return GetValue(value_ptr); -} +int64_t LSTM::get_input_size() const { return GetValue(GetAttr(kInput_size)); } void LSTM::set_hidden_size(const int64_t hidden_size) { CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name()); AddAttr(kHidden_size, MakeValue(hidden_size)); } -int64_t LSTM::get_hidden_size() const { - auto value_ptr = this->GetAttr(kHidden_size); - return GetValue(value_ptr); -} +int64_t LSTM::get_hidden_size() const { return GetValue(GetAttr(kHidden_size)); } void LSTM::set_num_layers(const int64_t num_layers) { CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name()); AddAttr(kNumLayers, MakeValue(num_layers)); } -int64_t LSTM::get_num_layers() const { - auto value_ptr = this->GetAttr(kNumLayers); - return GetValue(value_ptr); -} +int64_t LSTM::get_num_layers() const { return GetValue(GetAttr(kNumLayers)); } void LSTM::set_has_bias(const bool has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); } bool LSTM::get_has_bias() const { auto value_ptr = this->GetAttr(kHasBias); @@ -138,10 +135,7 @@ bool LSTM::get_bidirectional() const { return GetValue(value_ptr); } void LSTM::set_num_directions(const int64_t num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); } -int64_t LSTM::get_num_directions() const { - auto value_ptr = this->GetAttr(kNumDirections); - return GetValue(value_ptr); -} +int64_t LSTM::get_num_directions() const { return GetValue(GetAttr(kNumDirections)); } void LSTM::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); } float LSTM::get_zoneout_cell() const { return GetValue(this->GetAttr(kZoneoutCell)); } @@ -167,14 +161,6 @@ void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64 this->set_zoneout_hidden(zoneout_hidden); } -int64_t LSTM::get_good_ld(const int64_t dim, const int64_t type_size) { - int64_t ld = ((dim + (64 / type_size) - 1) / (64 / type_size)) * (64 / type_size); - if (ld * 256 == 0) { - return ld + 64 / type_size; - } - return ld; -} - AbstractBasePtr LstmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { return std::make_shared(LstmInfer(primitive, input_args)); diff --git a/mindspore/core/ops/matrix_diag.cc b/mindspore/core/ops/matrix_diag.cc index a97b1536ad..c8ed95fd2e 100644 --- a/mindspore/core/ops/matrix_diag.cc +++ b/mindspore/core/ops/matrix_diag.cc @@ -29,9 +29,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto matrixdiag_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(matrixdiag_prim); - auto prim_name = matrixdiag_prim->name(); + auto prim_name = primitive->name(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); auto assist_shape = CheckAndConvertUtils::ConvertShapePtrToShape("assist_shape", input_args[1]->BuildShape(), prim_name); diff --git a/mindspore/core/ops/max_pool.cc b/mindspore/core/ops/max_pool.cc index f32b7466ed..22201079b0 100644 --- a/mindspore/core/ops/max_pool.cc +++ b/mindspore/core/ops/max_pool.cc @@ -31,37 +31,25 @@ void MaxPool::set_pad_mode(const PadMode &pad_mode) { this->AddAttr(kPadMode, MakeValue(swi)); } -PadMode MaxPool::get_pad_mode() const { - auto value_ptr = GetAttr(kPadMode); - return 
PadMode(GetValue(value_ptr)); -} +PadMode MaxPool::get_pad_mode() const { return PadMode(GetValue(GetAttr(kPadMode))); } void MaxPool::set_kernel_size(const std::vector &kernel_size) { this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); } -std::vector MaxPool::get_kernel_size() const { - auto value_ptr = GetAttr(kKernelSize); - return GetValue>(value_ptr); -} +std::vector MaxPool::get_kernel_size() const { return GetValue>(GetAttr(kKernelSize)); } void MaxPool::set_strides(const std::vector &strides) { this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); } -std::vector MaxPool::get_strides() const { - auto value_ptr = GetAttr(kStrides); - return GetValue>(value_ptr); -} +std::vector MaxPool::get_strides() const { return GetValue>(GetAttr(kStrides)); } void MaxPool::set_format(const Format &format) { int64_t f = format; this->AddAttr(kFormat, MakeValue(f)); } -Format MaxPool::get_format() const { - auto value_ptr = GetAttr(kFormat); - return Format(GetValue(value_ptr)); -} +Format MaxPool::get_format() const { return Format(GetValue(GetAttr(kFormat))); } void MaxPool::set_pad(const std::vector &pad) { this->AddAttr(kPad, MakeValue(pad)); } diff --git a/mindspore/core/ops/maximum.cc b/mindspore/core/ops/maximum.cc index b212667203..fb37e72006 100644 --- a/mindspore/core/ops/maximum.cc +++ b/mindspore/core/ops/maximum.cc @@ -25,9 +25,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto maximum_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(maximum_prim); - auto op_name = maximum_prim->name(); + auto op_name = primitive->name(); return BroadCastInferShape(op_name, input_args); } diff --git a/mindspore/core/ops/merge.cc b/mindspore/core/ops/merge.cc index d20d2279d2..700e1b2309 100644 --- a/mindspore/core/ops/merge.cc +++ b/mindspore/core/ops/merge.cc @@ -28,9 +28,7 @@ namespace ops { AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto Merge_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(Merge_prim); - auto op_name = Merge_prim->name(); + auto op_name = primitive->name(); auto inputs_type = input_args[0]->BuildType()->cast()->elements(); auto inputs_shape = input_args[0]->BuildShape()->cast()->shape(); std::map args; diff --git a/mindspore/core/ops/mfcc.cc b/mindspore/core/ops/mfcc.cc index 73309e63ac..c52bec1a10 100644 --- a/mindspore/core/ops/mfcc.cc +++ b/mindspore/core/ops/mfcc.cc @@ -24,16 +24,15 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto mfcc_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(mfcc_prim); - auto prim_name = mfcc_prim->name(); + auto prim_name = primitive->name(); auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("second_input_shape", input_args[1]->BuildShape(), prim_name); CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name); CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name); - std::vector out_shape = {first_input_shape[0], first_input_shape[1], 
mfcc_prim->get_dct_coeff_num()}; + std::vector out_shape = {first_input_shape[0], first_input_shape[1], + GetValue(primitive->GetAttr(kDctCoeffNum))}; return std::make_shared(out_shape); } @@ -83,10 +82,7 @@ int64_t Mfcc::get_filter_bank_channel_num() const { void Mfcc::set_dct_coeff_num(const int64_t dct_coeff_num) { this->AddAttr(kDctCoeffNum, MakeValue(dct_coeff_num)); } -int64_t Mfcc::get_dct_coeff_num() const { - auto value_ptr = this->GetAttr(kDctCoeffNum); - return GetValue(value_ptr); -} +int64_t Mfcc::get_dct_coeff_num() const { return GetValue(GetAttr(kDctCoeffNum)); } AbstractBasePtr MfccInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { diff --git a/mindspore/core/ops/non_max_suppression.cc b/mindspore/core/ops/non_max_suppression.cc index 501adcb224..c0a289dcac 100644 --- a/mindspore/core/ops/non_max_suppression.cc +++ b/mindspore/core/ops/non_max_suppression.cc @@ -31,9 +31,6 @@ void NonMaxSuppression::Init(const int64_t center_point_box) { this->set_center_ AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto non_max_suppression_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(non_max_suppression_prim); MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime."; return std::make_shared(kInt32, std::vector{}); } diff --git a/mindspore/core/ops/one_hot.cc b/mindspore/core/ops/one_hot.cc index d1e2c11410..824b0810a1 100644 --- a/mindspore/core/ops/one_hot.cc +++ b/mindspore/core/ops/one_hot.cc @@ -25,17 +25,12 @@ namespace ops { void OneHot::Init(const int64_t axis) { this->set_axis(axis); } void OneHot::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); } -int64_t OneHot::get_axis() const { - auto value_ptr = this->GetAttr(kAxis); - return GetValue(value_ptr); -} +int64_t OneHot::get_axis() const { return GetValue(GetAttr(kAxis)); } namespace { abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto OneHot_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(OneHot_prim); - auto op_name = OneHot_prim->name(); - int64_t axis = OneHot_prim->get_axis(); + auto op_name = primitive->name(); + int64_t axis = GetValue(primitive->GetAttr(kAxis)); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name); auto depth_val = GetValue(input_args[1]->BuildValue()); @@ -50,9 +45,7 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); - auto OneHot_prim = prim->cast(); - MS_EXCEPTION_IF_NULL(OneHot_prim); - auto op_name = OneHot_prim->name(); + auto op_name = prim->name(); CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32}, op_name); CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name); std::map args = {{"on_value", input_args[2]->BuildType()}, diff --git a/mindspore/core/ops/ones_like.cc b/mindspore/core/ops/ones_like.cc index e1c62d2f6a..1a02b5e6a1 100644 --- a/mindspore/core/ops/ones_like.cc +++ b/mindspore/core/ops/ones_like.cc @@ -28,9 +28,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const 
PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto OnesLike_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(OnesLike_prim); - auto prim_name = OnesLike_prim->name(); + auto prim_name = primitive->name(); auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); return std::make_shared(input_shape); diff --git a/mindspore/core/ops/pack.cc b/mindspore/core/ops/pack.cc index e83a83dd44..ec31ed0dd7 100644 --- a/mindspore/core/ops/pack.cc +++ b/mindspore/core/ops/pack.cc @@ -50,23 +50,18 @@ std::vector _get_pack_shape(std::vector x_shapes, std::ve void Pack::set_axis(const int64_t &axis) { AddAttr(kAxis, MakeValue(axis)); } -int64_t Pack::get_axis() const { - auto value_ptr = this->GetAttr(kAxis); - return GetValue(value_ptr); -} +int64_t Pack::get_axis() const { return GetValue(GetAttr(kAxis)); } void Pack::Init(const int64_t &axis) { this->set_axis(axis); } AbstractBasePtr PackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto pack_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(pack_prim); - auto prim_name = pack_prim->name(); + auto prim_name = primitive->name(); auto x_shapes = input_args[0]->BuildShape()->cast()->shape(); auto x_types = input_args[0]->BuildType()->cast()->elements(); - auto all_shape = _get_pack_shape(x_shapes, x_types, pack_prim->get_axis(), prim_name); + auto all_shape = _get_pack_shape(x_shapes, x_types, GetValue(primitive->GetAttr(kAxis)), prim_name); auto tensor_type = x_types[0]->cast(); MS_EXCEPTION_IF_NULL(tensor_type); auto data_type = tensor_type->element(); diff --git a/mindspore/core/ops/pad.cc b/mindspore/core/ops/pad.cc index 5961fd7486..95a5be8fd1 100644 --- a/mindspore/core/ops/pad.cc +++ b/mindspore/core/ops/pad.cc @@ -23,10 +23,8 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto pad_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(pad_prim); - auto prim_name = pad_prim->name(); - auto paddings_attr = pad_prim->get_paddings(); + auto prim_name = primitive->name(); + auto paddings_attr = GetValue>>(primitive->GetAttr(kPaddings)); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Pad"); CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()), prim_name); @@ -59,8 +57,7 @@ void Pad::set_paddings(const std::vector> &paddings) { this->AddAttr(kPaddings, MakeValue(paddings)); } std::vector> Pad::get_paddings() const { - auto value_ptr = GetAttr(kPaddings); - return GetValue>>(value_ptr); + return GetValue>>(GetAttr(kPaddings)); } AbstractBasePtr PadInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { diff --git a/mindspore/core/ops/pow.cc b/mindspore/core/ops/pow.cc index 510c97bd4e..724857d439 100644 --- a/mindspore/core/ops/pow.cc +++ b/mindspore/core/ops/pow.cc @@ -24,9 +24,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto pow_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(pow_prim); - auto op_name = pow_prim->name(); + auto op_name = primitive->name(); return BroadCastInferShape(op_name, input_args); } diff --git a/mindspore/core/ops/prior_box.cc 
b/mindspore/core/ops/prior_box.cc index edb3057226..fad2b279a7 100644 --- a/mindspore/core/ops/prior_box.cc +++ b/mindspore/core/ops/prior_box.cc @@ -24,10 +24,7 @@ namespace mindspore { namespace ops { void PriorBox::set_min_sizes(const std::vector &min_sizes) { this->AddAttr(kMinSizes, MakeValue(min_sizes)); } -std::vector PriorBox::get_min_sizes() const { - auto value_ptr = GetAttr(kMinSizes); - return GetValue>(value_ptr); -} +std::vector PriorBox::get_min_sizes() const { return GetValue>(GetAttr(kMinSizes)); } void PriorBox::set_max_sizes(const std::vector &max_sizes) { this->AddAttr(kMaxSizes, MakeValue(max_sizes)); } @@ -40,10 +37,7 @@ void PriorBox::set_aspect_ratios(const std::vector &aspect_ratios) { this->AddAttr(kAspectRatios, MakeValue(aspect_ratios)); } -std::vector PriorBox::get_aspect_ratios() const { - auto value_ptr = GetAttr(kAspectRatios); - return GetValue>(value_ptr); -} +std::vector PriorBox::get_aspect_ratios() const { return GetValue>(GetAttr(kAspectRatios)); } void PriorBox::set_variances(const std::vector &variances) { this->AddAttr(kVariances, MakeValue(variances)); } @@ -89,10 +83,7 @@ bool PriorBox::get_clip() const { void PriorBox::set_flip(const bool flip) { this->AddAttr(kFlip, MakeValue(flip)); } -bool PriorBox::get_flip() const { - auto value_ptr = GetAttr(kFlip); - return GetValue(value_ptr); -} +bool PriorBox::get_flip() const { return GetValue(GetAttr(kFlip)); } void PriorBox::set_offset(const float offset) { this->AddAttr(kOffset, MakeValue(offset)); } @@ -121,25 +112,23 @@ void PriorBox::Init(const std::vector &min_sizes, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto PriorBox_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(PriorBox_prim); - auto op_name = PriorBox_prim->name(); + auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[0]); std::vector different_aspect_ratios{1.0f}; - auto aspect_ratios = PriorBox_prim->get_aspect_ratios(); + auto aspect_ratios = GetValue>(primitive->GetAttr(kAspectRatios)); for (int64_t i = 0; i < (int64_t)aspect_ratios.size(); i++) { float ratio = aspect_ratios[i]; bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(), [&](float v) { return abs(ratio - v) < 1e-6; }); if (!exist) { different_aspect_ratios.emplace_back(ratio); - if (PriorBox_prim->get_flip()) { + if (GetValue(primitive->GetAttr(kFlip))) { different_aspect_ratios.emplace_back(1.0f / ratio); } } } - int64_t num_priors_box = - PriorBox_prim->get_min_sizes().size() * different_aspect_ratios.size() + PriorBox_prim->get_max_sizes().size(); + auto min_sizes = GetValue>(primitive->GetAttr(kMinSizes)); + int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size(); auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); int64_t h = input[0] * input[1] * num_priors_box * 4; std::vector output_shape{1, h, 1, 2}; diff --git a/mindspore/core/ops/quant_dtype_cast.cc b/mindspore/core/ops/quant_dtype_cast.cc index 59d002e999..aee60bd0a9 100644 --- a/mindspore/core/ops/quant_dtype_cast.cc +++ b/mindspore/core/ops/quant_dtype_cast.cc @@ -24,10 +24,7 @@ int64_t QuantDTypeCast::get_src_t() const { return GetValue(value_ptr); } void QuantDTypeCast::set_dst_t(const int64_t dst_t) { AddAttr(kDstT, MakeValue(dst_t)); } -int64_t QuantDTypeCast::get_dst_t() const { - auto value_ptr = this->GetAttr(kDstT); - return GetValue(value_ptr); -} +int64_t QuantDTypeCast::get_dst_t() const { return GetValue(GetAttr(kDstT)); } 
void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) { this->set_src_t(src_t); this->set_dst_t(dst_t); @@ -35,16 +32,14 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) { AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto QuantDTypeCast_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(QuantDTypeCast_prim); - auto op_name = QuantDTypeCast_prim->name(); + auto op_name = primitive->name(); MS_EXCEPTION_IF_NULL(input_args[0]); auto input_type = input_args[0]->BuildType()->cast(); MS_EXCEPTION_IF_NULL(input_type); - MS_ASSERT(input_type->element() == TypeIdToType(TypeId(QuantDTypeCast_prim->get_dst_t()))); + auto dst_type = GetValue(primitive->GetAttr(kDstT)); + MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type))); auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name); - return std::make_shared(TypeIdToType(TypeId(QuantDTypeCast_prim->get_dst_t())), - input_shape); + return std::make_shared(TypeIdToType(TypeId(dst_type)), input_shape); } REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast); } // namespace ops diff --git a/mindspore/core/ops/range.cc b/mindspore/core/ops/range.cc index f28951d61e..ca56e34ee9 100644 --- a/mindspore/core/ops/range.cc +++ b/mindspore/core/ops/range.cc @@ -34,10 +34,7 @@ int64_t Range::get_d_type() const { void Range::set_start(const int64_t start) { this->AddAttr(kStart, MakeValue(start)); } -int64_t Range::get_start() const { - auto value_ptr = GetAttr(kStart); - return GetValue(value_ptr); -} +int64_t Range::get_start() const { return GetValue(GetAttr(kStart)); } void Range::set_limit(const int64_t limit) { this->AddAttr(kLimit, MakeValue(limit)); } @@ -63,10 +60,7 @@ void Range::Init(const int64_t d_type, const int64_t start, const int64_t limit, AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(prim); int64_t shape_size = 0; - TypeId dtype; if (input_args.size() == 3) { MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); MS_EXCEPTION_IF_NULL(input_args[1]->BuildValue()); @@ -74,7 +68,7 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP auto start_tensor = input_args[0]->BuildValue()->cast(); auto limit_tensor = input_args[1]->BuildValue()->cast(); auto delta_tensor = input_args[2]->BuildValue()->cast(); - dtype = static_cast(start_tensor->data_type_c()); + auto dtype = start_tensor->data_type(); switch (dtype) { case kNumberTypeInt: case kNumberTypeInt32: { @@ -97,9 +91,9 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP } } } else { - int64_t start = prim->get_start(); - int64_t limit = prim->get_limit(); - int64_t delta = prim->get_delta(); + int64_t start = GetValue(primitive->GetAttr(kStart)); + int64_t limit = GetValue(primitive->GetAttr(kLimit)); + int64_t delta = GetValue(primitive->GetAttr(kDelta)); shape_size = std::max(static_cast(std::ceil(LongToDouble(limit - start) / delta)), static_cast(0)); } diff --git a/mindspore/core/ops/rank.cc b/mindspore/core/ops/rank.cc index b10e324bee..b969ec1fd0 100644 --- a/mindspore/core/ops/rank.cc +++ b/mindspore/core/ops/rank.cc @@ -21,9 +21,7 @@ namespace ops { namespace { TypePtr RankInferType(const PrimitivePtr &prim, const std::vector 
&input_args) { MS_EXCEPTION_IF_NULL(prim); - auto Rank_prim = prim->cast(); - MS_EXCEPTION_IF_NULL(Rank_prim); - auto op_name = Rank_prim->name(); + auto op_name = prim->name(); auto infer_dtype = input_args[0]->BuildType(); CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, {kTensorType}, op_name); return kTypeNone; diff --git a/mindspore/core/ops/reciprocal.cc b/mindspore/core/ops/reciprocal.cc index 35d3202036..bca8df361c 100644 --- a/mindspore/core/ops/reciprocal.cc +++ b/mindspore/core/ops/reciprocal.cc @@ -29,9 +29,7 @@ namespace ops { AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto reciprocal_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(reciprocal_prim); - auto prim_name = reciprocal_prim->name(); + auto prim_name = primitive->name(); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/ops/reduce.cc b/mindspore/core/ops/reduce.cc index cad04f92e9..e037a21f9b 100644 --- a/mindspore/core/ops/reduce.cc +++ b/mindspore/core/ops/reduce.cc @@ -70,13 +70,11 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorBuildValue(); MS_EXCEPTION_IF_NULL(primitive); - auto reduce_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(reduce_prim); - auto prim_name = reduce_prim->name(); + auto prim_name = primitive->name(); auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x_shape", input_args[0]->BuildShape(), prim_name); - auto keep_dims = reduce_prim->get_keep_dims(); + auto keep_dims = GetValue(primitive->GetAttr(kKeepDims)); auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name); return std::make_shared(out_shape); @@ -93,10 +91,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & void Reduce::set_keep_dims(const bool keep_dims) { this->AddAttr(kKeepDims, MakeValue(keep_dims)); } -bool Reduce::get_keep_dims() const { - auto value_ptr = GetAttr(kKeepDims); - return GetValue(value_ptr); -} +bool Reduce::get_keep_dims() const { return GetValue(GetAttr(kKeepDims)); } void Reduce::Init(const bool keep_dims) { this->set_keep_dims(keep_dims); } diff --git a/mindspore/core/ops/resize_bilinear.cc b/mindspore/core/ops/resize_bilinear.cc index 4d9969ce64..5aca24e120 100644 --- a/mindspore/core/ops/resize_bilinear.cc +++ b/mindspore/core/ops/resize_bilinear.cc @@ -27,10 +27,7 @@ namespace mindspore { namespace ops { void ResizeBilinear::set_size(const std::vector &size) { this->AddAttr(kSize, MakeValue(size)); } -std::vector ResizeBilinear::get_size() const { - auto value_ptr = GetAttr(kSize); - return GetValue>(value_ptr); -} +std::vector ResizeBilinear::get_size() const { return GetValue>(GetAttr(kSize)); } void ResizeBilinear::set_align_corners(const bool align_corners) { this->AddAttr(kAlignCorners, MakeValue(align_corners)); @@ -48,9 +45,7 @@ void ResizeBilinear::Init(const std::vector &size, const bool align_cor AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto resize_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(resize_prim); - auto prim_name = resize_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name); // Infer shape @@ -58,7 +53,7 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr 
&, const P CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name); std::vector out_shape = {input_shape[0], input_shape[1]}; - auto size = resize_prim->get_size(); + auto size = GetValue>(primitive->GetAttr(kSize)); out_shape.insert(out_shape.end(), size.begin(), size.end()); // Infer type diff --git a/mindspore/core/ops/reverse_sequence.cc b/mindspore/core/ops/reverse_sequence.cc index b770c118f0..ead5d7c835 100644 --- a/mindspore/core/ops/reverse_sequence.cc +++ b/mindspore/core/ops/reverse_sequence.cc @@ -30,10 +30,7 @@ void ReverseSequence::Init(const int64_t seq_dim, const int64_t batch_dim) { void ReverseSequence::set_seq_dim(const int64_t seq_dim) { this->AddAttr(kSeqDim, MakeValue(seq_dim)); } void ReverseSequence::set_batch_dim(const int64_t batch_dim) { this->AddAttr(kBatchDim, MakeValue(batch_dim)); } -int64_t ReverseSequence::get_seq_dim() const { - auto value_ptr = this->GetAttr(kSeqDim); - return GetValue(value_ptr); -} +int64_t ReverseSequence::get_seq_dim() const { return GetValue(GetAttr(kSeqDim)); } int64_t ReverseSequence::get_batch_dim() const { auto value_ptr = this->GetAttr(kBatchDim); return GetValue(value_ptr); @@ -41,9 +38,7 @@ int64_t ReverseSequence::get_batch_dim() const { AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto reverse_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(reverse_prim); - auto prim_name = reverse_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); @@ -53,8 +48,8 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); auto seq_lengths = CheckAndConvertUtils::ConvertShapePtrToShape("seq_lengths", input_args[1]->BuildShape(), prim_name); - auto seq_dim = reverse_prim->get_seq_dim(); - auto batch_dim = reverse_prim->get_batch_dim(); + auto seq_dim = GetValue(primitive->GetAttr(kSeqDim)); + auto batch_dim = GetValue(primitive->GetAttr(kBatchDim)); CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name); CheckAndConvertUtils::CheckInteger("batch_dim", batch_dim, kLessEqual, input_shape.size(), prim_name); CheckAndConvertUtils::CheckInteger("batch_dim", batch_dim, kNotEqual, seq_dim, prim_name); diff --git a/mindspore/core/ops/reverse_v2.cc b/mindspore/core/ops/reverse_v2.cc index 3aa2e74f46..a247172d95 100644 --- a/mindspore/core/ops/reverse_v2.cc +++ b/mindspore/core/ops/reverse_v2.cc @@ -24,9 +24,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto reverseV2_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(reverseV2_prim); - auto prim_name = reverseV2_prim->name(); + auto prim_name = primitive->name(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); return std::make_shared(x_shape); } diff --git a/mindspore/core/ops/rfft.cc b/mindspore/core/ops/rfft.cc index 3dc04871e3..2e38650c8a 100644 --- a/mindspore/core/ops/rfft.cc +++ b/mindspore/core/ops/rfft.cc @@ -24,13 +24,11 @@ 
namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto rfft_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(rfft_prim); - auto prim_name = rfft_prim->name(); + auto prim_name = primitive->name(); auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name); auto out_shape = first_input_shape; - out_shape[out_shape.size() - 1] = rfft_prim->get_fft_length() / 2 + 1; + out_shape[out_shape.size() - 1] = GetValue(primitive->GetAttr(kFftLength)) / 2 + 1; out_shape.push_back(2); return std::make_shared(out_shape); } @@ -47,10 +45,7 @@ void Rfft::Init(const int64_t fft_length) { this->set_fft_length(fft_length); } void Rfft::set_fft_length(const int64_t fft_length) { this->AddAttr(kFftLength, MakeValue(fft_length)); } -int64_t Rfft::get_fft_length() const { - auto value_ptr = this->GetAttr(kFftLength); - return GetValue(value_ptr); -} +int64_t Rfft::get_fft_length() const { return GetValue(GetAttr(kFftLength)); } AbstractBasePtr RfftInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { diff --git a/mindspore/core/ops/roi_pooling.cc b/mindspore/core/ops/roi_pooling.cc index 7b421c6ed4..d0fa2cc11b 100644 --- a/mindspore/core/ops/roi_pooling.cc +++ b/mindspore/core/ops/roi_pooling.cc @@ -27,10 +27,7 @@ namespace mindspore { namespace ops { void ROIPooling::set_pooled_h(const int64_t pooled_h) { this->AddAttr(kPooledH, MakeValue(pooled_h)); } -int64_t ROIPooling::get_pooled_h() const { - auto value_ptr = GetAttr(kPooledH); - return GetValue(value_ptr); -} +int64_t ROIPooling::get_pooled_h() const { return GetValue(GetAttr(kPooledH)); } void ROIPooling::set_pooled_w(const int64_t pooled_w) { this->AddAttr(kPooledW, MakeValue(pooled_w)); } @@ -54,9 +51,7 @@ void ROIPooling::Init(const int64_t pooled_h, const int64_t pooled_w, const floa AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto roi_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(roi_prim); - auto prim_name = roi_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("roi_pooling_infer", input_args.size(), kEqual, 2, prim_name); MS_EXCEPTION_IF_NULL(input_args[0]); MS_EXCEPTION_IF_NULL(input_args[1]); @@ -65,8 +60,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi auto output_data_type = input_args[0]->BuildType()->cast()->element(); // Infer shape - auto new_h = roi_prim->get_pooled_h(); - auto new_w = roi_prim->get_pooled_w(); + auto new_h = GetValue(primitive->GetAttr(kPooledH)); + auto new_w = GetValue(primitive->GetAttr(kPooledW)); auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShape("roi_shape", input_args[1]->BuildShape(), prim_name); diff --git a/mindspore/core/ops/rsqrt.cc b/mindspore/core/ops/rsqrt.cc index c1d84f9b23..b9bab65707 100644 --- a/mindspore/core/ops/rsqrt.cc +++ b/mindspore/core/ops/rsqrt.cc @@ -29,9 +29,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto rsqrt_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(rsqrt_prim); - auto prim_name = rsqrt_prim->name(); + auto 
prim_name = primitive->name(); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->GetShapeTrack(), prim_name); CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name); return std::make_shared(in_shape); diff --git a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc index 83144e7bc5..d316d0c521 100644 --- a/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc +++ b/mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc @@ -29,9 +29,7 @@ namespace ops { AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto sigmoid_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(sigmoid_prim); - auto prim_name = sigmoid_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", input_args.size(), kEqual, 2, prim_name); diff --git a/mindspore/core/ops/skip_gram.cc b/mindspore/core/ops/skip_gram.cc index e4375fd202..59c3737caf 100644 --- a/mindspore/core/ops/skip_gram.cc +++ b/mindspore/core/ops/skip_gram.cc @@ -23,9 +23,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto SkipGram_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(SkipGram_prim); - auto prim_name = SkipGram_prim->name(); + auto prim_name = primitive->name(); if (input_args.size() != 1) { MS_LOG(ERROR) << "Skip Gram should have one input"; } diff --git a/mindspore/core/ops/smooth_l1_loss.cc b/mindspore/core/ops/smooth_l1_loss.cc index f43b58c708..ee908b7ea8 100644 --- a/mindspore/core/ops/smooth_l1_loss.cc +++ b/mindspore/core/ops/smooth_l1_loss.cc @@ -36,9 +36,7 @@ float SmoothL1Loss::get_beta() const { AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto smooth_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(smooth_prim); - auto prim_name = smooth_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name); // Infer shape diff --git a/mindspore/core/ops/softmax_cross_entropy_with_logits.cc b/mindspore/core/ops/softmax_cross_entropy_with_logits.cc index 3a05576720..2c0dadaeb0 100644 --- a/mindspore/core/ops/softmax_cross_entropy_with_logits.cc +++ b/mindspore/core/ops/softmax_cross_entropy_with_logits.cc @@ -29,9 +29,7 @@ namespace ops { AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto softmax_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(softmax_prim); - auto prim_name = softmax_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", input_args.size(), kEqual, 2, prim_name); diff --git a/mindspore/core/ops/space_to_batch.cc b/mindspore/core/ops/space_to_batch.cc index bf0ca42118..93aaa99041 100644 --- a/mindspore/core/ops/space_to_batch.cc +++ b/mindspore/core/ops/space_to_batch.cc @@ -28,15 +28,13 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto 
spacetobatch_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(spacetobatch_prim); - auto prim_name = spacetobatch_prim->name(); + auto prim_name = primitive->name(); auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name); std::vector output_shape(input_shape.size()); - auto block_shape_vector = spacetobatch_prim->get_block_size(); - auto paddings = spacetobatch_prim->get_paddings(); + auto block_shape_vector = GetValue>(primitive->GetAttr(kBlockSize)); + auto paddings = GetValue>>(primitive->GetAttr(kPaddings)); for (size_t i = 0; i < 2; i++) { auto padded = output_shape[i + 2] + paddings[i][0] + paddings[i][1]; CheckAndConvertUtils::CheckInteger("padded shape", padded % block_shape_vector.size(), kEqual, 0, prim_name); @@ -77,8 +75,7 @@ void SpaceToBatch::set_block_size(const std::vector block_size) { } std::vector SpaceToBatch::get_block_size() const { - auto value_ptr = GetAttr(kBlockSize); - return GetValue>(value_ptr); + return GetValue>(GetAttr(kBlockSize)); } void SpaceToBatch::Init(const std::vector block_size, const std::vector> &paddings) { diff --git a/mindspore/core/ops/space_to_batch_nd.cc b/mindspore/core/ops/space_to_batch_nd.cc index 4117c15d4a..2075ba7085 100644 --- a/mindspore/core/ops/space_to_batch_nd.cc +++ b/mindspore/core/ops/space_to_batch_nd.cc @@ -28,16 +28,14 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto space_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(space_prim); - auto prim_name = space_prim->name(); + auto prim_name = primitive->name(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); auto out_shape = x_shape; int64_t block_shape_prod = 1; const int64_t offset = 2; - auto block_shape = space_prim->get_block_shape(); - auto padding = space_prim->get_paddings(); + auto block_shape = GetValue>(primitive->GetAttr(kBlockShape)); + auto padding = GetValue>>(primitive->GetAttr(kPaddings)); int64_t size = block_shape.size(); for (int64_t i = 0; i < size; i++) { int64_t padded = out_shape[i + offset] + padding[i][0] + padding[i][1]; @@ -87,8 +85,7 @@ void SpaceToBatchND::set_block_shape(std::vector block_shape) { } std::vector SpaceToBatchND::get_block_shape() const { - auto value_ptr = GetAttr(kBlockShape); - return GetValue>(value_ptr); + return GetValue>(GetAttr(kBlockShape)); } void SpaceToBatchND::Init(std::vector block_shape, std::vector> paddings) { diff --git a/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc index ce0633897b..2a61a73fff 100644 --- a/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc +++ b/mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc @@ -31,18 +31,13 @@ void SparseSoftmaxCrossEntropyWithLogits::set_is_grad(const bool is_grad) { this->AddAttr(kIsGrad, MakeValue(is_grad)); } -bool SparseSoftmaxCrossEntropyWithLogits::get_is_grad() const { - auto value_ptr = GetAttr(kIsGrad); - return GetValue(value_ptr); -} +bool SparseSoftmaxCrossEntropyWithLogits::get_is_grad() const { return GetValue(GetAttr(kIsGrad)); } AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, const 
PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto sparse_softmax_cross_entropy_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(sparse_softmax_cross_entropy_prim); - auto prim_name = sparse_softmax_cross_entropy_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); @@ -51,7 +46,7 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name); std::vector output_shape; - if (sparse_softmax_cross_entropy_prim->get_is_grad() != 0) { + if (GetValue(primitive->GetAttr(kIsGrad)) != 0) { output_shape = input_shape; } else { output_shape.push_back(1); diff --git a/mindspore/core/ops/sparse_to_dense.cc b/mindspore/core/ops/sparse_to_dense.cc index ea8fbf896f..e7f68e3648 100644 --- a/mindspore/core/ops/sparse_to_dense.cc +++ b/mindspore/core/ops/sparse_to_dense.cc @@ -27,9 +27,7 @@ namespace ops { AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto spasetodense_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(spasetodense_prim); - auto prim_name = spasetodense_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 3, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); diff --git a/mindspore/core/ops/squared_difference.cc b/mindspore/core/ops/squared_difference.cc index 672da25fbb..f7a9531181 100644 --- a/mindspore/core/ops/squared_difference.cc +++ b/mindspore/core/ops/squared_difference.cc @@ -27,9 +27,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto squared_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(squared_prim); - auto op_name = squared_prim->name(); + auto op_name = primitive->name(); return BroadCastInferShape(op_name, input_args); } diff --git a/mindspore/core/ops/squeeze.cc b/mindspore/core/ops/squeeze.cc index 6bc3971e56..4ca7b5c3c3 100644 --- a/mindspore/core/ops/squeeze.cc +++ b/mindspore/core/ops/squeeze.cc @@ -20,18 +20,13 @@ namespace mindspore { namespace ops { void Squeeze::Init(const std::vector &axis) { set_axis(axis); } void Squeeze::set_axis(const std::vector &axis) { AddAttr(kAxis, MakeValue(axis)); } -std::vector Squeeze::get_axis() const { - auto value_ptr = this->GetAttr(kAxis); - return GetValue>(value_ptr); -} +std::vector Squeeze::get_axis() const { return GetValue>(GetAttr(kAxis)); } namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto squeeze_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(squeeze_prim); - auto op_name = squeeze_prim->name(); - auto axis = squeeze_prim->get_axis(); + auto op_name = primitive->name(); + auto axis = GetValue>(primitive->GetAttr(kAxis)); std::vector infer_shape; auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); diff --git a/mindspore/core/ops/stack.cc b/mindspore/core/ops/stack.cc index 9396b31736..81c2870c6e 100644 --- a/mindspore/core/ops/stack.cc +++ b/mindspore/core/ops/stack.cc @@ -21,9 +21,7 
@@ namespace ops {
 namespace {
 abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto stack_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(stack_prim);
-  auto prim_name = stack_prim->name();
+  auto prim_name = primitive->name();
 
   if (input_args.size() != 1) {
     MS_LOG(ERROR) << "Invalid output size:" << input_args.size();
@@ -46,7 +44,7 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
     }
   }
   std::vector infer_shape = input_shape;
-  infer_shape.insert(infer_shape.begin() + stack_prim->get_axis(), input_args.size());
+  infer_shape.insert(infer_shape.begin() + GetValue(primitive->GetAttr(kAxis)), input_args.size());
 
   auto infer_type0 = input_args[0]->BuildType()->cast()->element();
   for (int64_t i = 1; i < (int64_t)input_args.size(); i++) {
@@ -64,10 +62,7 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
 void Stack::set_axis(const int64_t axis) { AddAttr(kAxis, MakeValue(axis)); }
 
-int64_t Stack::get_axis() const {
-  auto value_ptr = this->GetAttr(kAxis);
-  return GetValue(value_ptr);
-}
+int64_t Stack::get_axis() const { return GetValue(GetAttr(kAxis)); }
 
 void Stack::Init(const int64_t axis) { this->set_axis(axis); }
diff --git a/mindspore/core/ops/strided_slice.cc b/mindspore/core/ops/strided_slice.cc
index 680ee4ab90..b0346a2b44 100644
--- a/mindspore/core/ops/strided_slice.cc
+++ b/mindspore/core/ops/strided_slice.cc
@@ -28,12 +28,79 @@ namespace mindspore {
 namespace ops {
 namespace {
+std::vector<int64_t> TenToTwo(int64_t num) {
+  std::vector<int64_t> output;
+  if (num == 0) {
+    output.push_back(0);
+    return output;
+  }
+  while (num) {
+    output.push_back(num % 2);
+    num /= 2;
+  }
+
+  return output;
+}
+
+int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, std::vector<int64_t> x_shape,
+                               int64_t i) {
+  if (i > (int64_t)x_shape.size()) {
+    MS_EXCEPTION(ValueError) << "For 'StridedSlice', when there is no new axis, "
+                                "the index length must be less than or equal to the dim of x.";
+  }
+  int64_t x_dim = x_shape[i];
+  int64_t slicing_length = 0;
+  if (strides > 0) {
+    if ((start_pos >= x_dim) || end_pos < -x_dim) {
+      slicing_length = 0;
+    } else {
+      if (-x_dim <= start_pos && start_pos < 0) {
+        start_pos += x_dim;
+      }
+      if (start_pos < -x_dim) {
+        start_pos = 0;
+      }
+      if (-x_dim <= end_pos && end_pos < 0) {
+        end_pos += x_dim;
+      }
+      if (end_pos > x_dim) {
+        end_pos = x_dim;
+      }
+      if (start_pos > end_pos) {
+        slicing_length = 0;
+      } else {
+        slicing_length = 1 + (end_pos - 1 - start_pos) / strides;
+      }
+    }
+  } else {
+    if (start_pos < -x_dim || end_pos >= x_dim) {
+      slicing_length = 0;
+    } else {
+      if (0 < start_pos && start_pos < x_dim) {
+        start_pos += -x_dim;
+      }
+      if (start_pos >= x_dim) {
+        start_pos = -1;
+      }
+      if (0 <= end_pos && end_pos < x_dim) {
+        end_pos += -x_dim;
+      }
+      if (end_pos < -x_dim - 1) {
+        end_pos = -x_dim - 1;
+      }
+      if (start_pos <= end_pos) {
+        slicing_length = 0;
+      } else {
+        slicing_length = 1 + (end_pos + 1 - start_pos) / strides;
+      }
+    }
+  }
+  return slicing_length;
+}
 abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
                                           const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto strided_slice_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(strided_slice_prim);
-  auto prim_name = strided_slice_prim->name();
+  auto prim_name = primitive->name();
   auto temp_begin_v = input_args[1]->cast()->BuildValue();
   auto begin_v = GetValue>(temp_begin_v);
   auto temp_end_v = 
input_args[2]->cast()->BuildValue(); @@ -44,11 +111,11 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); int64_t x_rank = x_shape.size(); int64_t slice_len = begin_v.size(); - std::vector begin_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_begin_mask()); - std::vector end_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_end_mask()); - std::vector ellipsis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_ellipsis_mask()); - std::vector new_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_new_axis_mask()); - std::vector shrink_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_shrink_axis_mask()); + std::vector begin_pos = TenToTwo(GetValue(primitive->GetAttr(kBeginMask))); + std::vector end_pos = TenToTwo(GetValue(primitive->GetAttr(kEndMask))); + std::vector ellipsis_pos = TenToTwo(GetValue(primitive->GetAttr(kEllipsisMask))); + std::vector new_axis_pos = TenToTwo(GetValue(primitive->GetAttr(kNewAxisMask))); + std::vector shrink_axis_pos = TenToTwo(GetValue(primitive->GetAttr(kShrinkAxisMask))); int64_t i = 0; int64_t j = 0; @@ -91,7 +158,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, finish = x_shape[0]; strides = 1; } - slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_shape, i); + slicing_length = compute_slicing_length(start, finish, strides, x_shape, i); infer_shape.push_back(slicing_length); i += 1; j += 1; @@ -132,7 +199,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, i += 1; continue; } - slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_shape, i); + slicing_length = compute_slicing_length(start, finish, strides, x_shape, i); infer_shape.push_back(slicing_length); i += 1; j += 1; @@ -154,10 +221,7 @@ void StridedSlice::set_begin_mask(const int64_t begin_mask) { CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name()); this->AddAttr(kBeginMask, MakeValue(begin_mask)); } -int64_t StridedSlice::get_begin_mask() const { - auto value_ptr = GetAttr(kBeginMask); - return GetValue(value_ptr); -} +int64_t StridedSlice::get_begin_mask() const { return GetValue(GetAttr(kBeginMask)); } void StridedSlice::set_end_mask(const int64_t end_mask) { CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name()); this->AddAttr(kEndMask, MakeValue(end_mask)); @@ -205,76 +269,6 @@ void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const this->set_shrink_axis_mask(shrink_axis_mask); } -std::vector StridedSlice::TenToTwo(int64_t num) { - std::vector output; - if (num == 0) { - output.push_back(0); - return output; - } - while (num) { - output.push_back(num % 2); - num /= 2; - } - - return output; -} - -int64_t StridedSlice::compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, - std::vector x_shape, int64_t i) { - if (i > (int64_t)x_shape.size()) { - MS_EXCEPTION(ValueError) << "For 'StridedSlice', When their is no new axis, " - "the index length must be less or equal than the dim of x."; - } - int64_t x_dim = x_shape[i]; - int64_t slicing_length = 0; - if (strides > 0) { - if ((start_pos >= x_dim) || end_pos < -x_dim) { - slicing_length = 0; - } else { - if (-x_dim <= start_pos && start_pos < 0) { - start_pos += x_dim; - } - if (start_pos < -x_dim) { - start_pos = 0; - } - if 
(-x_dim <= end_pos && end_pos < 0) { - end_pos += x_dim; - } - if (end_pos > x_dim) { - end_pos = x_dim; - } - if (start_pos > end_pos) { - slicing_length = 0; - } else { - slicing_length = 1 + (end_pos - 1 - start_pos) / strides; - } - } - } else { - if (start_pos < -x_dim || end_pos >= x_dim) { - slicing_length = 0; - } else { - if (0 < start_pos && start_pos < x_dim) { - start_pos += -x_dim; - } - if (start_pos >= x_dim) { - start_pos = -1; - } - if (0 <= end_pos && end_pos < x_dim) { - end_pos += -x_dim; - } - if (end_pos < -x_dim - 1) { - end_pos = -x_dim - 1; - } - if (start_pos <= end_pos) { - slicing_length = 0; - } else { - slicing_length = 1 + (end_pos + 1 - start_pos) / strides; - } - } - } - return slicing_length; -} - AbstractBasePtr StridedSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { return std::make_shared(StridedSliceInferType(primitive, input_args), diff --git a/mindspore/core/ops/sub.cc b/mindspore/core/ops/sub.cc index 0a6187126e..c45b15d440 100644 --- a/mindspore/core/ops/sub.cc +++ b/mindspore/core/ops/sub.cc @@ -29,9 +29,7 @@ namespace ops { namespace { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto sub_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(sub_prim); - auto prim_name = sub_prim->name(); + auto prim_name = primitive->name(); return BroadCastInferShape(prim_name, input_args); } diff --git a/mindspore/core/ops/tan.cc b/mindspore/core/ops/tan.cc index 6d21bc1355..72ba1b4d80 100644 --- a/mindspore/core/ops/tan.cc +++ b/mindspore/core/ops/tan.cc @@ -29,9 +29,7 @@ namespace ops { AbstractBasePtr TanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto tan_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(tan_prim); - auto prim_name = tan_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("tan_infer", input_args.size(), kEqual, 1, prim_name); // Infer Shape diff --git a/mindspore/core/ops/tensor_list_from_tensor.cc b/mindspore/core/ops/tensor_list_from_tensor.cc index b6c3e4169c..6b82a140bb 100644 --- a/mindspore/core/ops/tensor_list_from_tensor.cc +++ b/mindspore/core/ops/tensor_list_from_tensor.cc @@ -24,9 +24,7 @@ namespace { abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto tensor_list_from_tensor_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(tensor_list_from_tensor_prim); - auto prim_name = tensor_list_from_tensor_prim->name(); + auto prim_name = primitive->name(); auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0 shape", input_args[0]->BuildShape(), prim_name); auto input1_shape = diff --git a/mindspore/core/ops/tensor_list_stack.cc b/mindspore/core/ops/tensor_list_stack.cc index 054caa79f7..a701381652 100644 --- a/mindspore/core/ops/tensor_list_stack.cc +++ b/mindspore/core/ops/tensor_list_stack.cc @@ -49,12 +49,10 @@ int64_t TensorListStack::get_element_dtype() const { AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto TensorListStack_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(TensorListStack_prim); - for (auto input : input_args) { + for (const auto &input : input_args) { MS_EXCEPTION_IF_NULL(input); } - auto op_name = 
TensorListStack_prim->name(); + auto op_name = primitive->name(); auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name); int64_t num = std::accumulate(input0_shape.begin(), input0_shape.end(), 1LL, std::multiplies()); diff --git a/mindspore/core/ops/tile.cc b/mindspore/core/ops/tile.cc index 76910c4eba..fa3580d24e 100644 --- a/mindspore/core/ops/tile.cc +++ b/mindspore/core/ops/tile.cc @@ -25,9 +25,7 @@ namespace ops { namespace { abstract::ShapePtr TileInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto tile_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(tile_prim); - auto prim_name = tile_prim->name(); + auto prim_name = primitive->name(); auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x shape", input_args[0]->BuildShape(), prim_name); auto multiples_v = GetValue>(input_args[1]->cast()->BuildValue()); int len_sub = input_shape.size() - multiples_v.size(); diff --git a/mindspore/core/ops/topk.cc b/mindspore/core/ops/topk.cc index d7d1b727d7..c87a761b2d 100644 --- a/mindspore/core/ops/topk.cc +++ b/mindspore/core/ops/topk.cc @@ -31,9 +31,7 @@ bool TopK::get_sorted() const { AbstractBasePtr TopKInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto top_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(top_prim); - auto prim_name = top_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("top_k_infer", input_args.size(), kEqual, 2, prim_name); // Infer dtype diff --git a/mindspore/core/ops/unpack.cc b/mindspore/core/ops/unpack.cc index 76db999095..2c9ba02b5f 100644 --- a/mindspore/core/ops/unpack.cc +++ b/mindspore/core/ops/unpack.cc @@ -20,22 +20,17 @@ namespace mindspore { namespace ops { void Unpack::Init(const int64_t axis) { this->set_axis(axis); } void Unpack::set_axis(const int64_t axis) { AddAttr(kAxis, MakeValue(axis)); } -int64_t Unpack::get_axis() const { - auto value_ptr = this->GetAttr(kAxis); - return GetValue(value_ptr); -} +int64_t Unpack::get_axis() const { return GetValue(GetAttr(kAxis)); } AbstractBasePtr UnpackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto unpack_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(unpack_prim); - auto prim_name = unpack_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckSubClass("x", input_args[0]->BuildType(), {TypeIdToType(kObjectTypeTensorType)}, prim_name); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); int64_t dim = x_shape.size(); - int64_t axis = unpack_prim->get_axis(); + int64_t axis = GetValue(primitive->GetAttr(kAxis)); // CheckAndConvertUtils::CheckInRange("axis value", axis, kIncludeLeft, {-dim, dim}, prim_name); if (axis < 0) { axis = axis + dim; diff --git a/mindspore/core/ops/unsorted_segment_sum.cc b/mindspore/core/ops/unsorted_segment_sum.cc index fe17df609a..5a1916c2e9 100644 --- a/mindspore/core/ops/unsorted_segment_sum.cc +++ b/mindspore/core/ops/unsorted_segment_sum.cc @@ -28,9 +28,7 @@ namespace ops { AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto unsortedsegmentsum_prim = primitive->cast(); - 
MS_EXCEPTION_IF_NULL(unsortedsegmentsum_prim); - auto prim_name = unsortedsegmentsum_prim->name(); + auto prim_name = primitive->name(); // Infer type auto x_type = input_args[0]->BuildType()->cast()->element(); diff --git a/mindspore/core/ops/unsqueeze.cc b/mindspore/core/ops/unsqueeze.cc index d4337ce152..1737bf672a 100644 --- a/mindspore/core/ops/unsqueeze.cc +++ b/mindspore/core/ops/unsqueeze.cc @@ -25,16 +25,11 @@ void Unsqueeze::Init(const std::vector axis) { this->set_axis(axis); } void Unsqueeze::set_axis(std::vector axis) { this->AddAttr(kAxis, MakeValue(axis)); } -std::vector Unsqueeze::get_axis() const { - auto value_ptr = this->GetAttr(kAxis); - return GetValue>(value_ptr); -} +std::vector Unsqueeze::get_axis() const { return GetValue>(GetAttr(kAxis)); } AbstractBasePtr UnsqueezeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto unsqueeze_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(unsqueeze_prim); - auto prim_name = unsqueeze_prim->name(); + auto prim_name = primitive->name(); CheckAndConvertUtils::CheckInteger("unsqueeze_infer", input_args.size(), kEqual, 1, prim_name); MS_EXCEPTION_IF_NULL(input_args[0]); auto input = input_args[0]; @@ -43,7 +38,7 @@ AbstractBasePtr UnsqueezeInfer(const abstract::AnalysisEnginePtr &, const Primit auto input_type = input->BuildType()->cast()->element(); // Infer shape - auto dims = unsqueeze_prim->get_axis(); + auto dims = GetValue>(primitive->GetAttr(kAxis)); auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input", input->BuildShape(), prim_name); auto input_rank = input_shape.size(); auto dim_rank = dims.size(); diff --git a/mindspore/core/ops/unstack.cc b/mindspore/core/ops/unstack.cc index 371986dde3..3186aabde3 100644 --- a/mindspore/core/ops/unstack.cc +++ b/mindspore/core/ops/unstack.cc @@ -21,19 +21,14 @@ namespace ops { void Unstack::Init(const int64_t axis) { this->set_axis(axis); } void Unstack::set_axis(const int64_t axis) { AddAttr(kAxis, MakeValue(axis)); } -int64_t Unstack::get_axis() const { - auto value_ptr = this->GetAttr(kAxis); - return GetValue(value_ptr); -} +int64_t Unstack::get_axis() const { return GetValue(GetAttr(kAxis)); } AbstractBasePtr UnstackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto unstack_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(unstack_prim); - auto prim_name = unstack_prim->name(); + auto prim_name = primitive->name(); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name); int64_t dim = x_shape.size(); - int64_t axis = unstack_prim->get_axis(); + int64_t axis = GetValue(primitive->GetAttr(kAxis)); // CheckAndConvertUtils::CheckInRange("axis value", axis, kIncludeLeft, {-dim, dim}, prim_name); if (axis < 0) { axis = axis + dim; diff --git a/mindspore/core/ops/where.cc b/mindspore/core/ops/where.cc index f583d5751a..27c11c6d3e 100644 --- a/mindspore/core/ops/where.cc +++ b/mindspore/core/ops/where.cc @@ -25,12 +25,10 @@ namespace ops { AbstractBasePtr WhereInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto Where_prim = primitive->cast(); - MS_EXCEPTION_IF_NULL(Where_prim); for (auto input : input_args) { MS_EXCEPTION_IF_NULL(input); } - auto op_name = Where_prim->name(); + auto op_name = primitive->name(); 
   CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 3, op_name);
   auto input0_type_ = input_args[0]->BuildType()->cast();
   MS_EXCEPTION_IF_NULL(input0_type_);
diff --git a/mindspore/core/ops/zeros_like.cc b/mindspore/core/ops/zeros_like.cc
index 10009e0dc0..63adb63f49 100644
--- a/mindspore/core/ops/zeros_like.cc
+++ b/mindspore/core/ops/zeros_like.cc
@@ -30,9 +30,7 @@ namespace ops {
 namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto zeroslike_prim = primitive->cast();
-  MS_EXCEPTION_IF_NULL(zeroslike_prim);
-  auto prim_name = zeroslike_prim->name();
+  auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
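The change repeated across every file in this patch is the same: drop the primitive->cast<...>() to a concrete op class and read what the infer function needs directly from the base Primitive, via name() and GetValue<T>(primitive->GetAttr(key)). Below is a minimal, self-contained C++ sketch of that access pattern; Value, Primitive, GetValue, and InferStackedRank are simplified stand-ins invented for illustration, not the MindSpore types or API.

// Standalone sketch (not MindSpore code): read op attributes from a generic
// Primitive so infer helpers never need to cast to a concrete op class.
#include <cstdint>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins for ValuePtr / Primitive; names are illustrative only.
using Value = std::variant<bool, int64_t, std::vector<int64_t>>;

class Primitive {
 public:
  explicit Primitive(std::string name) : name_(std::move(name)) {}
  const std::string &name() const { return name_; }
  void AddAttr(const std::string &key, Value value) { attrs_[key] = std::move(value); }
  const Value &GetAttr(const std::string &key) const { return attrs_.at(key); }

 private:
  std::string name_;
  std::map<std::string, Value> attrs_;
};

// Mirrors the GetValue<T>(primitive->GetAttr(k)) calls introduced by the diff.
template <typename T>
T GetValue(const Value &value) {
  return std::get<T>(value);
}

// Toy infer helper: everything it needs (name and attributes) comes from the
// base Primitive, so no cast to a Stack-specific class is required.
int64_t InferStackedRank(const Primitive &prim, int64_t input_rank) {
  auto axis = GetValue<int64_t>(prim.GetAttr("axis"));
  if (axis < -(input_rank + 1) || axis > input_rank) {
    throw std::out_of_range(prim.name() + ": axis out of range");
  }
  return input_rank + 1;  // stacking along a new axis adds one dimension
}

int main() {
  Primitive stack("Stack");
  stack.AddAttr("axis", int64_t{0});
  std::cout << stack.name() << " output rank: " << InferStackedRank(stack, 2) << "\n";
  return 0;
}

Keeping attribute access on the base class gives one generic code path per infer function and avoids the failure mode where the cast to the derived op type returns null; that appears to be the motivation behind the "avoid cast when doing infer" change, though the patch itself does not state it.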