Browse Source

!15283 remove cast primitive when doing cast

From: @lianliguang
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
pull/15283/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
7951d3f8cb
100 changed files with 272 additions and 634 deletions
  1. +0
    -20
      mindspore/core/abstract/infer_functions.h
  2. +0
    -17
      mindspore/core/abstract/prim_nn.cc
  3. +1
    -3
      mindspore/core/ops/abs.cc
  4. +1
    -3
      mindspore/core/ops/adam.cc
  5. +1
    -3
      mindspore/core/ops/add.cc
  6. +1
    -3
      mindspore/core/ops/arg_max.cc
  7. +3
    -8
      mindspore/core/ops/arg_min.cc
  8. +1
    -3
      mindspore/core/ops/asin.cc
  9. +1
    -3
      mindspore/core/ops/assert.cc
  10. +1
    -3
      mindspore/core/ops/assign_add.cc
  11. +1
    -3
      mindspore/core/ops/atan.cc
  12. +12
    -12
      mindspore/core/ops/audio_spectrogram.cc
  13. +2
    -2
      mindspore/core/ops/audio_spectrogram.h
  14. +11
    -25
      mindspore/core/ops/avg_pool.cc
  15. +7
    -7
      mindspore/core/ops/batch_norm.cc
  16. +1
    -3
      mindspore/core/ops/batch_norm_fold.cc
  17. +3
    -5
      mindspore/core/ops/batch_to_space.cc
  18. +3
    -5
      mindspore/core/ops/batch_to_space_nd.cc
  19. +3
    -5
      mindspore/core/ops/binary_cross_entropy.cc
  20. +1
    -3
      mindspore/core/ops/broadcast.cc
  21. +2
    -4
      mindspore/core/ops/concat.cc
  22. +1
    -3
      mindspore/core/ops/constant_of_shape.cc
  23. +1
    -3
      mindspore/core/ops/conv2d_transpose.cc
  24. +1
    -3
      mindspore/core/ops/cos.cc
  25. +1
    -3
      mindspore/core/ops/crop.cc
  26. +1
    -3
      mindspore/core/ops/custom_extract_features.cc
  27. +0
    -5
      mindspore/core/ops/custom_normalize.cc
  28. +1
    -3
      mindspore/core/ops/custom_predict.cc
  29. +7
    -14
      mindspore/core/ops/depth_to_space.cc
  30. +25
    -42
      mindspore/core/ops/depthwise_conv2d.cc
  31. +11
    -22
      mindspore/core/ops/detection_post_process.cc
  32. +1
    -3
      mindspore/core/ops/div.cc
  33. +1
    -3
      mindspore/core/ops/dropout.cc
  34. +1
    -3
      mindspore/core/ops/elu.cc
  35. +1
    -3
      mindspore/core/ops/equal.cc
  36. +1
    -3
      mindspore/core/ops/expand_dims.cc
  37. +1
    -3
      mindspore/core/ops/fake_quant_with_min_max_vars.cc
  38. +1
    -3
      mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
  39. +1
    -3
      mindspore/core/ops/fft_imag.cc
  40. +1
    -3
      mindspore/core/ops/flatten.cc
  41. +1
    -3
      mindspore/core/ops/floor.cc
  42. +1
    -3
      mindspore/core/ops/fusion/add_fusion.cc
  43. +7
    -8
      mindspore/core/ops/fusion/avg_pool_fusion.cc
  44. +12
    -21
      mindspore/core/ops/fusion/full_connection.cc
  45. +7
    -8
      mindspore/core/ops/fusion/max_pool_fusion.cc
  46. +1
    -3
      mindspore/core/ops/fusion/pow_fusion.cc
  47. +1
    -3
      mindspore/core/ops/fusion/slice_fusion.cc
  48. +1
    -3
      mindspore/core/ops/gather_nd.cc
  49. +1
    -3
      mindspore/core/ops/gelu.cc
  50. +0
    -2
      mindspore/core/ops/grad/avg_pool_grad.cc
  51. +1
    -3
      mindspore/core/ops/grad/batch_norm_grad.cc
  52. +1
    -3
      mindspore/core/ops/grad/bias_add_grad.cc
  53. +1
    -3
      mindspore/core/ops/grad/binary_cross_entropy_grad.cc
  54. +2
    -6
      mindspore/core/ops/grad/dropout_grad.cc
  55. +3
    -6
      mindspore/core/ops/grad/group_conv2d_grad_input.cc
  56. +2
    -3
      mindspore/core/ops/grad/max_pool_grad.cc
  57. +1
    -3
      mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc
  58. +1
    -3
      mindspore/core/ops/grad/smooth_l1_loss_grad.cc
  59. +1
    -3
      mindspore/core/ops/hashtable_lookup.cc
  60. +3
    -8
      mindspore/core/ops/l2_normalize.cc
  61. +1
    -3
      mindspore/core/ops/less.cc
  62. +1
    -3
      mindspore/core/ops/less_equal.cc
  63. +1
    -3
      mindspore/core/ops/logical_and.cc
  64. +2
    -6
      mindspore/core/ops/logical_not.cc
  65. +1
    -3
      mindspore/core/ops/logical_or.cc
  66. +1
    -3
      mindspore/core/ops/lrn.cc
  67. +3
    -11
      mindspore/core/ops/lsh_projection.cc
  68. +21
    -35
      mindspore/core/ops/lstm.cc
  69. +1
    -3
      mindspore/core/ops/matrix_diag.cc
  70. +4
    -16
      mindspore/core/ops/max_pool.cc
  71. +1
    -3
      mindspore/core/ops/maximum.cc
  72. +1
    -3
      mindspore/core/ops/merge.cc
  73. +4
    -8
      mindspore/core/ops/mfcc.cc
  74. +0
    -3
      mindspore/core/ops/non_max_suppression.cc
  75. +4
    -11
      mindspore/core/ops/one_hot.cc
  76. +1
    -3
      mindspore/core/ops/ones_like.cc
  77. +3
    -8
      mindspore/core/ops/pack.cc
  78. +3
    -6
      mindspore/core/ops/pad.cc
  79. +1
    -3
      mindspore/core/ops/pow.cc
  80. +8
    -19
      mindspore/core/ops/prior_box.cc
  81. +5
    -10
      mindspore/core/ops/quant_dtype_cast.cc
  82. +5
    -11
      mindspore/core/ops/range.cc
  83. +1
    -3
      mindspore/core/ops/rank.cc
  84. +1
    -3
      mindspore/core/ops/reciprocal.cc
  85. +3
    -8
      mindspore/core/ops/reduce.cc
  86. +3
    -8
      mindspore/core/ops/resize_bilinear.cc
  87. +4
    -9
      mindspore/core/ops/reverse_sequence.cc
  88. +1
    -3
      mindspore/core/ops/reverse_v2.cc
  89. +3
    -8
      mindspore/core/ops/rfft.cc
  90. +4
    -9
      mindspore/core/ops/roi_pooling.cc
  91. +1
    -3
      mindspore/core/ops/rsqrt.cc
  92. +1
    -3
      mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc
  93. +1
    -3
      mindspore/core/ops/skip_gram.cc
  94. +1
    -3
      mindspore/core/ops/smooth_l1_loss.cc
  95. +1
    -3
      mindspore/core/ops/softmax_cross_entropy_with_logits.cc
  96. +4
    -7
      mindspore/core/ops/space_to_batch.cc
  97. +4
    -7
      mindspore/core/ops/space_to_batch_nd.cc
  98. +3
    -8
      mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc
  99. +1
    -3
      mindspore/core/ops/sparse_to_dense.cc
  100. +1
    -3
      mindspore/core/ops/squared_difference.cc

+ 0
- 20
mindspore/core/abstract/infer_functions.h View File

@@ -53,22 +53,10 @@ AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFastGeLU(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFastGeLUGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -153,10 +141,6 @@ AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);

@@ -174,8 +158,6 @@ AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -302,8 +284,6 @@ AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &pri
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 0
- 17
mindspore/core/abstract/prim_nn.cc View File

@@ -140,23 +140,6 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
}
}

AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
auto is_grad = GetValue<bool>(primitive->GetAttr("is_grad"));
CheckArgsSize(primitive->name(), args_spec_list, 2);
std::shared_ptr<BaseShape> shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{});
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
if (is_grad) {
shape = args_spec_list[0]->BuildShape();
}
auto type = args_spec_list[0]->BuildType();
MS_EXCEPTION_IF_NULL(type);
auto type_tensor = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(type_tensor);
return std::make_shared<abstract::AbstractTensor>(type_tensor->element(), shape);
}

AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: five tensors(x, gamma, beta, mean, variance).


+ 1
- 3
mindspore/core/ops/abs.cc View File

@@ -30,9 +30,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto abs_prim = primitive->cast<PrimAbsPtr>();
MS_EXCEPTION_IF_NULL(abs_prim);
auto prim_name = abs_prim->name();
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}


+ 1
- 3
mindspore/core/ops/adam.cc View File

@@ -23,9 +23,7 @@ namespace ops {
namespace {
abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto Adam_prim = primitive->cast<PrimAdamPtr>();
MS_EXCEPTION_IF_NULL(Adam_prim);
auto prim_name = Adam_prim->name();
auto prim_name = primitive->name();

// infer shape
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name);


+ 1
- 3
mindspore/core/ops/add.cc View File

@@ -27,9 +27,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto add_prim = primitive->cast<PrimAddPtr>();
MS_EXCEPTION_IF_NULL(add_prim);
auto prim_name = add_prim->name();
auto prim_name = primitive->name();
return BroadCastInferShape(prim_name, input_args);
}



+ 1
- 3
mindspore/core/ops/arg_max.cc View File

@@ -22,9 +22,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto prim = primitive->cast<PrimArgMaxPtr>();
MS_EXCEPTION_IF_NULL(prim);
auto axis = prim->get_axis();
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);


+ 3
- 8
mindspore/core/ops/arg_min.cc View File

@@ -27,10 +27,7 @@ void ArgMin::Init(const int64_t axis, const TypeId output_type) {
void ArgMin::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
void ArgMin::set_output_type(const TypeId output_type) { this->AddAttr(kOutputType, TypeIdToType(output_type)); }

int64_t ArgMin::get_axis() const {
auto value_ptr = GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
int64_t ArgMin::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }

TypeId ArgMin::get_output_type() const {
auto type_ptr = GetAttr(kOutputType)->cast<TensorTypePtr>()->element();
@@ -40,13 +37,11 @@ TypeId ArgMin::get_output_type() const {
AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto argmin_prim = primitive->cast<PrimArgMin>();
MS_EXCEPTION_IF_NULL(argmin_prim);
auto prim_name = argmin_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("arg_min_infer", input_args.size(), kEqual, 1, prim_name);

// Infer shape
auto axis = argmin_prim->get_axis();
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);


+ 1
- 3
mindspore/core/ops/asin.cc View File

@@ -25,9 +25,7 @@ namespace ops {
AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto asin_prim = primitive->cast<PrimAsinPtr>();
MS_EXCEPTION_IF_NULL(asin_prim);
auto prim_name = asin_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name);

// Infer Shape


+ 1
- 3
mindspore/core/ops/assert.cc View File

@@ -37,9 +37,7 @@ int64_t Assert::get_summarize() const {
AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto Assert_prim = primitive->cast<PrimAssertPtr>();
MS_EXCEPTION_IF_NULL(Assert_prim);
auto op_name = Assert_prim->name();
auto op_name = primitive->name();
TypePtr condition;
if (!(input_args[0]->BuildType()->type_id() == kObjectTypeTensorType)) {
auto condition_value = GetValue<std::vector<bool>>(input_args[0]->BuildValue());


+ 1
- 3
mindspore/core/ops/assign_add.cc View File

@@ -25,9 +25,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto assignadd_prim = primitive->cast<PrimAssignAddPtr>();
MS_EXCEPTION_IF_NULL(assignadd_prim);
auto prim_name = assignadd_prim->name();
auto prim_name = primitive->name();
auto value_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name);
return std::make_shared<abstract::Shape>(value_shape);


+ 1
- 3
mindspore/core/ops/atan.cc View File

@@ -23,9 +23,7 @@ namespace ops {
AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto atan_prim = primitive->cast<PrimAtanPtr>();
MS_EXCEPTION_IF_NULL(atan_prim);
auto prim_name = atan_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name);

// Infer Shape


+ 12
- 12
mindspore/core/ops/audio_spectrogram.cc View File

@@ -30,25 +30,25 @@ namespace {
abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto audio_spectrogram_prim = primitive->cast<PrimAudioSpectrogramPtr>();
MS_EXCEPTION_IF_NULL(audio_spectrogram_prim);
auto prim_name = audio_spectrogram_prim->name();
auto prim_name = primitive->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
if (input_shape.size() != 2) {
MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
}
if (audio_spectrogram_prim->get_window_size() < 2) {
MS_LOG(ERROR) << "window size is too short, now is " << audio_spectrogram_prim->get_window_size();
auto window_size = GetValue<int64_t>(primitive->GetAttr(kWindowSize));
if (window_size < 2) {
MS_LOG(ERROR) << "window size is too short, now is " << window_size;
}
if (audio_spectrogram_prim->get_stride() < 1) {
MS_LOG(ERROR) << "stride must be positive, now is " << audio_spectrogram_prim->get_stride();
auto stride_size = GetValue<int64_t>(primitive->GetAttr(kStride));
if (stride_size < 1) {
MS_LOG(ERROR) << "stride must be positive, now is " << stride_size;
}
std::vector<int64_t> infer_shape;
infer_shape.push_back(input_shape[1]);
int64_t sample_sub_window = input_shape[0] - audio_spectrogram_prim->get_window_size();
infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / audio_spectrogram_prim->get_stride());
int64_t fft_length = audio_spectrogram_prim->GetFftLength(audio_spectrogram_prim->get_window_size());
int64_t sample_sub_window = input_shape[0] - window_size;
infer_shape.push_back(sample_sub_window < 0 ? 0 : 1 + sample_sub_window / stride_size);
int64_t fft_length = GetFftLength(window_size);
infer_shape.push_back(fft_length / 2 + 1);
MS_LOG(ERROR) << infer_shape;
return std::make_shared<abstract::Shape>(infer_shape);
@@ -81,7 +81,7 @@ int64_t AudioSpectrogram::get_stride() const {
return GetValue<int64_t>(value_ptr);
}

int64_t AudioSpectrogram::Log2Ceil(int64_t length) {
int64_t Log2Ceil(int64_t length) {
if (length == 0) {
return -1;
}
@@ -97,7 +97,7 @@ int64_t AudioSpectrogram::Log2Ceil(int64_t length) {
return length == (length & ~(unsigned int)(length - 1)) ? floor : floor + 1;
}

int64_t AudioSpectrogram::GetFftLength(int64_t length) {
int64_t GetFftLength(int64_t length) {
int64_t shift = Log2Ceil(length);
return 1 << (unsigned int)shift;
}


+ 2
- 2
mindspore/core/ops/audio_spectrogram.h View File

@@ -27,6 +27,8 @@
namespace mindspore {
namespace ops {
constexpr auto kNameAudioSpectrogram = "AudioSpectrogram";
int64_t Log2Ceil(int64_t length);
int64_t GetFftLength(int64_t length);
class AudioSpectrogram : public PrimitiveC {
public:
AudioSpectrogram() : PrimitiveC(kNameAudioSpectrogram) {}
@@ -39,8 +41,6 @@ class AudioSpectrogram : public PrimitiveC {
int64_t get_window_size() const;
int64_t get_stride() const;
bool get_mag_square() const;
int64_t Log2Ceil(int64_t length);
int64_t GetFftLength(int64_t length);
};
AbstractBasePtr AudioSpectrogramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);


+ 11
- 25
mindspore/core/ops/avg_pool.cc View File

@@ -31,37 +31,25 @@ void AvgPool::set_pad_mode(const PadMode &pad_mode) {
this->AddAttr(kPadMode, MakeValue(swi));
}

PadMode AvgPool::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
return PadMode(GetValue<int64_t>(value_ptr));
}
PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
void AvgPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
}

std::vector<int64_t> AvgPool::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> AvgPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
void AvgPool::set_strides(const std::vector<int64_t> &strides) {
this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
}

std::vector<int64_t> AvgPool::get_strides() const {
auto value_ptr = GetAttr(kStrides);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> AvgPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }

void AvgPool::set_format(const Format &format) {
int64_t f = format;
this->AddAttr(kFormat, MakeValue(f));
}

Format AvgPool::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
Format AvgPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }

void AvgPool::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); }

@@ -93,22 +81,20 @@ void AvgPool::Init(const std::vector<int64_t> &kernel_size, const std::vector<in
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto pool_prim = primitive->cast<PrimAvgPoolPtr>();
MS_EXCEPTION_IF_NULL(pool_prim);
auto op_name = pool_prim->name();
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
if (pool_prim->get_format() == NHWC) {
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
}
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
auto kernel_size = pool_prim->get_kernel_size();
auto pad_mode = pool_prim->get_pad_mode();
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
auto batch = in_shape[0];
auto channel = in_shape[1];
auto in_h = in_shape[2];
auto in_w = in_shape[3];

auto strides = pool_prim->get_strides();
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
auto kernel_h = kernel_size[2];
auto kernel_w = kernel_size[3];
auto stride_h = strides[2];
@@ -123,7 +109,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
out_w = ceil(in_w / stride_w);
}
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
if (pool_prim->get_format() == NHWC) {
if (format == NHWC) {
out_shape = {batch, out_h, out_w, channel};
}
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {


+ 7
- 7
mindspore/core/ops/batch_norm.cc View File

@@ -72,13 +72,12 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
const std::vector<AbstractBasePtr> &input_args) {
// Infer shape
MS_EXCEPTION_IF_NULL(primitive);
auto batch_prim = primitive->cast<PrimBatchNormPtr>();
MS_EXCEPTION_IF_NULL(batch_prim);
auto prim_name = batch_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name);

auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
if (batch_prim->get_format() == NHWC) {
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
}
auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name);
@@ -87,7 +86,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name);

std::vector<int64_t> input_shape_norm;
if (batch_prim->get_format() == NCHW) {
if (format == NCHW) {
input_shape_norm =
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
} else {
@@ -100,7 +99,8 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError);
CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name,
TypeError);
if (!batch_prim->get_is_training()) {

if (!GetValue<bool>(primitive->GetAttr(kIsTraining))) {
CheckAndConvertUtils::CheckInteger("mean rank", mean.size(), kEqual, 1, prim_name);
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError);
CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError);
@@ -126,7 +126,7 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
auto output2 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
auto output3 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);
if (batch_prim->get_format() == NHWC) {
if (format == NHWC) {
output2 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
output3 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
output1 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);


+ 1
- 3
mindspore/core/ops/batch_norm_fold.cc View File

@@ -67,9 +67,7 @@ int64_t BatchNormFold::get_freeze_bn() const {
AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto BatchNormFold_prim = primitive->cast<PrimBatchNormFoldPtr>();
MS_EXCEPTION_IF_NULL(BatchNormFold_prim);
auto op_name = BatchNormFold_prim->name();
auto op_name = primitive->name();
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name);
auto variance_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name);


+ 3
- 5
mindspore/core/ops/batch_to_space.cc View File

@@ -47,9 +47,7 @@ std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const {
AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim = primitive->cast<PrimBatchToSpacePtr>();
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
@@ -59,8 +57,8 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri

auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
auto block_size = prim->get_block_size();
auto crops = prim->get_crops();
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
auto out_shape = x_shape;
for (size_t i = 0; i < 2; ++i) {
auto x_block_prod = out_shape[i + 2] * block_size[i];


+ 3
- 5
mindspore/core/ops/batch_to_space_nd.cc View File

@@ -28,16 +28,14 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto batch_prim = primitive->cast<PrimBatchToSpaceNDPtr>();
MS_EXCEPTION_IF_NULL(batch_prim);
auto prim_name = batch_prim->name();
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
auto out_shape = x_shape;
int64_t block_shape_prod = 1;
int64_t offset = 2;
auto block_shape = batch_prim->get_block_shape();
auto crops = batch_prim->get_crops();
auto block_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockShape));
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
int64_t size = block_shape.size();
for (int64_t i = 0; i < size; i++) {
block_shape_prod = block_shape_prod * block_shape[i];


+ 3
- 5
mindspore/core/ops/binary_cross_entropy.cc View File

@@ -32,9 +32,7 @@ namespace {
abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto binary_cross_entropy_prim = primitive->cast<PrimBinaryCrossEntropyPtr>();
MS_EXCEPTION_IF_NULL(binary_cross_entropy_prim);
auto prim_name = binary_cross_entropy_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
@@ -45,8 +43,8 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
if (weight_shape.size() < 1) {
CheckAndConvertUtils::Check("x shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);
}
if (binary_cross_entropy_prim->get_reduction() != REDUCTION_SUM &&
binary_cross_entropy_prim->get_reduction() != MEAN) {
auto reduction = Reduction(GetValue<int64_t>(primitive->GetAttr(kReduction)));
if (reduction != REDUCTION_SUM && reduction != MEAN) {
infer_shape = {x_shape.begin(), infer_shape.end()};
}
return std::make_shared<abstract::Shape>(infer_shape);


+ 1
- 3
mindspore/core/ops/broadcast.cc View File

@@ -45,9 +45,7 @@ std::string Broadcast::get_group() const {
AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto broadcast_prim = primitive->cast<PrimBroadcast>();
MS_EXCEPTION_IF_NULL(broadcast_prim);
auto prim_name = broadcast_prim->name();
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}


+ 2
- 4
mindspore/core/ops/concat.cc View File

@@ -32,9 +32,7 @@ void Concat::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)
AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim = primitive->cast<PrimConcatPtr>();
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
@@ -48,7 +46,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
auto element0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
auto element0_rank = SizeToLong(element0_shape.size());
auto axis = prim->get_axis();
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank},
prim_name);
axis = axis < 0 ? axis + element0_rank : axis;


+ 1
- 3
mindspore/core/ops/constant_of_shape.cc View File

@@ -31,9 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A

TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto constant_prim = primitive->cast<PrimConstantOfShapePtr>();
MS_EXCEPTION_IF_NULL(constant_prim);
auto data_type = TypeId(constant_prim->get_data_type());
auto data_type = TypeId(GetValue<int64_t>(primitive->GetAttr(kDataType)));
return TypeIdToType(data_type);
}
} // namespace


+ 1
- 3
mindspore/core/ops/conv2d_transpose.cc View File

@@ -28,9 +28,7 @@ namespace {
abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto conv2d_transpose_prim = primitive->cast<PrimConv2dTransposePtr>();
MS_EXCEPTION_IF_NULL(conv2d_transpose_prim);
auto prim_name = conv2d_transpose_prim->name();
auto prim_name = primitive->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name);
return std::make_shared<abstract::Shape>(input_shape);
}


+ 1
- 3
mindspore/core/ops/cos.cc View File

@@ -24,9 +24,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto cos_prim = primitive->cast<PrimCos>();
MS_EXCEPTION_IF_NULL(cos_prim);
auto prim_name = cos_prim->name();
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}


+ 1
- 3
mindspore/core/ops/crop.cc View File

@@ -43,9 +43,7 @@ std::vector<int64_t> Crop::get_offsets() const {
AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto crop_prim = primitive->cast<PrimCrop>();
MS_EXCEPTION_IF_NULL(crop_prim);
auto prim_name = crop_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);


+ 1
- 3
mindspore/core/ops/custom_extract_features.cc View File

@@ -24,9 +24,7 @@ namespace ops {
AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto extract_prim = primitive->cast<PrimCustomExtractFeaturesPtr>();
MS_EXCEPTION_IF_NULL(extract_prim);
auto prim_name = extract_prim->name();
auto prim_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
// auto input = input_args[0];



+ 0
- 5
mindspore/core/ops/custom_normalize.cc View File

@@ -24,13 +24,8 @@ namespace {
abstract::ShapePtr CustomNormalizeInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto custom_normalize_prim = primitive->cast<PrimCustomNormalizePtr>();
MS_EXCEPTION_IF_NULL(custom_normalize_prim);
// auto prim_name = custom_normalize_prim->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
MS_EXCEPTION_IF_NULL(input_args[0]->BuildShape());
// auto input_shape =
// CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
if (input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c() == nullptr) {
MS_LOG(ERROR) << "Do infer shape in runtime.";
}


+ 1
- 3
mindspore/core/ops/custom_predict.cc View File

@@ -45,13 +45,11 @@ float CustomPredict::get_weight_threshold() const {
AbstractBasePtr CustomPredictInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto CustomPredict_prim = primitive->cast<PrimCustomPredictPtr>();
MS_EXCEPTION_IF_NULL(CustomPredict_prim);
for (const auto &input : input_args) {
MS_EXCEPTION_IF_NULL(input);
}
std::vector<int64_t> shape;
shape.push_back(CustomPredict_prim->get_output_num());
shape.push_back(GetValue<int64_t>(primitive->GetAttr(kOutputNum)));

auto output0 = std::make_shared<abstract::AbstractTensor>(kInt32, shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);


+ 7
- 14
mindspore/core/ops/depth_to_space.cc View File

@@ -30,19 +30,13 @@ void DepthToSpace::set_block_size(const int64_t block_size) {
this->AddAttr(kBlockSize, MakeValue(block_size));
}

int64_t DepthToSpace::get_block_size() const {
auto value_ptr = GetAttr(kBlockSize);
return GetValue<int64_t>(value_ptr);
}
int64_t DepthToSpace::get_block_size() const { return GetValue<int64_t>(GetAttr(kBlockSize)); }
void DepthToSpace::set_format(const Format &format) {
int64_t f = format;
this->AddAttr(kFormat, MakeValue(f));
}

Format DepthToSpace::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
Format DepthToSpace::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }

void DepthToSpace::Init(const int64_t block_size, const Format &format) {
this->set_block_size(block_size);
@@ -52,9 +46,7 @@ void DepthToSpace::Init(const int64_t block_size, const Format &format) {
AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim = primitive->cast<PrimDepthToSpacePtr>();
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
@@ -63,18 +55,19 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
MS_EXCEPTION_IF_NULL(input_x);

auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
if (prim->get_format() == NHWC) {
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
}
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
int64_t block_size = prim->get_block_size();
int64_t block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)", x_shape[1] % (block_size * block_size),
kEqual, 0, prim_name);
auto out_shape = x_shape;
out_shape[1] /= block_size * block_size;
out_shape[2] *= block_size;
out_shape[3] *= block_size;
if (prim->get_format() == NHWC) {
if (format == NHWC) {
out_shape = {out_shape[0], out_shape[2], out_shape[3], out_shape[1]};
}
auto ret = input_x->Broaden();


+ 25
- 42
mindspore/core/ops/depthwise_conv2d.cc View File

@@ -65,25 +65,14 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vector<i
}

std::vector<int64_t> DepthWiseConv2D::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> DepthWiseConv2D::get_stride() const {
auto value_ptr = GetAttr(kStride);
return GetValue<std::vector<int64_t>>(value_ptr);
return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize));
}
std::vector<int64_t> DepthWiseConv2D::get_stride() const { return GetValue<std::vector<int64_t>>(GetAttr(kStride)); }
std::vector<int64_t> DepthWiseConv2D::get_dilation() const {
auto value_ptr = GetAttr(kDilation);
return GetValue<std::vector<int64_t>>(value_ptr);
}
PadMode DepthWiseConv2D::get_pad_mode() const {
auto value_ptr = this->GetAttr(kPadMode);
return PadMode(GetValue<int64_t>(value_ptr));
}
std::vector<int64_t> DepthWiseConv2D::get_pad() const {
auto value_ptr = this->GetAttr(kPad);
return GetValue<std::vector<int64_t>>(value_ptr);
return GetValue<std::vector<int64_t>>(GetAttr(kDilation));
}
PadMode DepthWiseConv2D::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
std::vector<int64_t> DepthWiseConv2D::get_pad() const { return GetValue<std::vector<int64_t>>(GetAttr(kPad)); }

std::vector<int64_t> DepthWiseConv2D::get_pads() const {
auto value_ptr = this->GetAttr(kPads);
@@ -99,10 +88,7 @@ int64_t DepthWiseConv2D::get_group() const {
auto value_ptr = this->GetAttr(kGroup);
return GetValue<int64_t>(value_ptr);
}
int64_t DepthWiseConv2D::get_out_channel() const {
auto value_ptr = this->GetAttr(kOutChannel);
return GetValue<int64_t>(value_ptr);
}
int64_t DepthWiseConv2D::get_out_channel() const { return GetValue<int64_t>(GetAttr(kOutChannel)); }

void DepthWiseConv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize, MakeValue(kernel_size));
@@ -126,33 +112,29 @@ void DepthWiseConv2D::set_format(const Format &format) {
this->AddAttr(kFormat, MakeValue(f));
}

Format DepthWiseConv2D::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
Format DepthWiseConv2D::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }

abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto conv_prim = primitive->cast<PrimDepthWiseConv2DPtr>();
MS_EXCEPTION_IF_NULL(conv_prim);
auto prim_name = conv_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
if (conv_prim->get_format() == NHWC) {
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
}
CheckAndConvertUtils::CheckInteger("weight_rank", w_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::CheckInteger("x_rank", x_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::Check("x_shape[1]", x_shape[1], kEqual, "w_shape[1]", w_shape[1], conv_prim->name());
auto out_channel = conv_prim->get_out_channel();
CheckAndConvertUtils::Check("x_shape[1]", x_shape[1], kEqual, "w_shape[1]", w_shape[1], prim_name);
auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));

std::vector<int64_t> temp_w;
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w,
conv_prim->name());
CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
"w_shape[2:4]", temp_w, prim_name);

auto kernel_size_n = w_shape[0];
if (kernel_size_n != 1) {
@@ -160,8 +142,8 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive,
}
auto kernel_size_h = w_shape[2];
auto kernel_size_w = w_shape[3];
auto stride = conv_prim->get_stride();
auto dilation = conv_prim->get_dilation();
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation));
auto stride_h = stride[2];
auto stride_w = stride[3];
auto dilation_h = dilation[2];
@@ -169,7 +151,7 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive,
int64_t h_out = -1;
int64_t w_out = -1;
std::vector<int64_t> pad_list(4, 0);
auto pad_mode = conv_prim->get_pad_mode();
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
if (pad_mode == VALID) {
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
@@ -187,20 +169,21 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive,
pad_list.emplace_back(pad_left);
pad_list.emplace_back(pad_needed_h - pad_left);
} else if (pad_mode == PAD) {
std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list));
auto pad_top = conv_prim->get_pad()[0];
auto pad_bottom = conv_prim->get_pad()[1];
auto pad_right = conv_prim->get_pad()[2];
auto pad_left = conv_prim->get_pad()[3];
auto pads = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
std::copy(pads.begin(), pads.end(), std::back_inserter(pad_list));
auto pad_top = pads[0];
auto pad_bottom = pads[1];
auto pad_right = pads[2];
auto pad_left = pads[3];

h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
h_out = floor(h_out);
w_out = floor(w_out);
}
conv_prim->set_pads(pad_list);
primitive->AddAttr(kPads, MakeValue(pad_list));
std::vector<int64_t> out_shape = {x_shape[0], out_channel * x_shape[1], h_out, w_out};
if (conv_prim->get_format() == NHWC) {
if (format == NHWC) {
out_shape = {x_shape[0], h_out, w_out, out_channel * x_shape[1]};
}
return std::make_shared<abstract::Shape>(out_shape);


+ 11
- 22
mindspore/core/ops/detection_post_process.cc View File

@@ -68,10 +68,7 @@ float DetectionPostProcess::get_nms_score_threshold() const {
void DetectionPostProcess::set_max_detections(const int64_t MaxDetections) {
this->AddAttr(kMaxDetections, MakeValue(MaxDetections));
}
int64_t DetectionPostProcess::get_max_detections() const {
auto value_ptr = this->GetAttr(kMaxDetections);
return GetValue<int64_t>(value_ptr);
}
int64_t DetectionPostProcess::get_max_detections() const { return GetValue<int64_t>(GetAttr(kMaxDetections)); }

void DetectionPostProcess::set_detections_per_class(const int64_t DetectionsPerClass) {
this->AddAttr(kDetectionsPerClass, MakeValue(DetectionsPerClass));
@@ -85,17 +82,13 @@ void DetectionPostProcess::set_max_classes_per_detection(const int64_t MaxClasse
this->AddAttr(kMaxClassesPerDetection, MakeValue(MaxClassesPerDetection));
}
int64_t DetectionPostProcess::get_max_classes_per_detection() const {
auto value_ptr = this->GetAttr(kMaxClassesPerDetection);
return GetValue<int64_t>(value_ptr);
return GetValue<int64_t>(GetAttr(kMaxClassesPerDetection));
}

void DetectionPostProcess::set_num_classes(const int64_t NumClasses) {
this->AddAttr(kNumClasses, MakeValue(NumClasses));
}
int64_t DetectionPostProcess::get_num_classes() const {
auto value_ptr = this->GetAttr(kNumClasses);
return GetValue<int64_t>(value_ptr);
}
int64_t DetectionPostProcess::get_num_classes() const { return GetValue<int64_t>(GetAttr(kNumClasses)); }
void DetectionPostProcess::set_use_regular_nms(const bool UseRegularNms) {
this->AddAttr(kUseRegularNms, MakeValue(UseRegularNms));
}
@@ -115,16 +108,11 @@ void DetectionPostProcess::set_format(const Format &format) {
int64_t f = format;
this->AddAttr(kFormat, MakeValue(f));
}
Format DetectionPostProcess::get_format() const {
auto value_ptr = this->GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
Format DetectionPostProcess::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto detection_prim = primitive->cast<PrimDetectionPostProcessPtr>();
MS_EXCEPTION_IF_NULL(detection_prim);
auto prim_name = detection_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("detection_post_process_infer", input_args.size(), kEqual, 3, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
MS_EXCEPTION_IF_NULL(input_args[1]);
@@ -135,12 +123,13 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShape("boxes_shape", boxes->BuildShape(), prim_name);
auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShape("scores_shape", scores->BuildShape(), prim_name);
auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShape("anchors_shape", anchors->BuildShape(), prim_name);
if (detection_prim->get_format() == NHWC) {
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]};
scores_shape = {scores_shape[0], scores_shape[3], scores_shape[1], scores_shape[2]};
anchors_shape = {anchors_shape[0], anchors_shape[3], anchors_shape[1], anchors_shape[2]};
}
auto num_classes = detection_prim->get_num_classes();
auto num_classes = GetValue<int64_t>(primitive->GetAttr(kNumClasses));
CheckAndConvertUtils::CheckInRange("scores_shape[2]", scores_shape[2], kIncludeBoth, {num_classes, num_classes + 1},
prim_name);
CheckAndConvertUtils::Check("boxes_shape[1]", boxes_shape[1], kEqual, "scores_shape[1]", scores_shape[1], prim_name,
@@ -149,8 +138,8 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
ValueError);

// Infer shape
auto max_detections = detection_prim->get_max_detections();
auto max_classes_per_detection = detection_prim->get_max_classes_per_detection();
auto max_detections = GetValue<int64_t>(primitive->GetAttr(kMaxDetections));
auto max_classes_per_detection = GetValue<int64_t>(primitive->GetAttr(kMaxClassesPerDetection));
auto num_detected_boxes = max_detections * max_classes_per_detection;
std::vector<int64_t> output_boxes_shape = {1, num_detected_boxes, 4};
std::vector<int64_t> output_class_shape = {1, num_detected_boxes};
@@ -163,7 +152,7 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
auto output1 = std::make_shared<abstract::AbstractTensor>(output_type, output_class_shape);
auto output2 = std::make_shared<abstract::AbstractTensor>(output_type, output_num_shape);
AbstractBasePtrList output = {output0, output1, output1, output2};
if (detection_prim->get_format() == NHWC) {
if (format == NHWC) {
output = {output0, output1, output2, output1};
}
return std::make_shared<abstract::AbstractTuple>(output);


+ 1
- 3
mindspore/core/ops/div.cc View File

@@ -28,9 +28,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto div_prim = primitive->cast<PrimDivPtr>();
MS_EXCEPTION_IF_NULL(div_prim);
auto prim_name = div_prim->name();
auto prim_name = primitive->name();
return BroadCastInferShape(prim_name, input_args);
}



+ 1
- 3
mindspore/core/ops/dropout.cc View File

@@ -39,9 +39,7 @@ float Dropout::get_keep_prob() const {
AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto dropout_prim = primitive->cast<PrimDropoutPtr>();
MS_EXCEPTION_IF_NULL(dropout_prim);
auto prim_name = dropout_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name);

// Infer shape


+ 1
- 3
mindspore/core/ops/elu.cc View File

@@ -31,9 +31,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto elu_prim = primitive->cast<PrimElu>();
MS_EXCEPTION_IF_NULL(elu_prim);
auto op_name = elu_prim->name();
auto op_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}


+ 1
- 3
mindspore/core/ops/equal.cc View File

@@ -29,9 +29,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto equal_prim = primitive->cast<PrimEqualPtr>();
MS_EXCEPTION_IF_NULL(equal_prim);
auto op_name = equal_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



+ 1
- 3
mindspore/core/ops/expand_dims.cc View File

@@ -30,9 +30,7 @@ namespace ops {
AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto expand_dims_prim = primitive->cast<PrimExpandDims>();
MS_EXCEPTION_IF_NULL(expand_dims_prim);
auto prim_name = expand_dims_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);


+ 1
- 3
mindspore/core/ops/fake_quant_with_min_max_vars.cc View File

@@ -28,9 +28,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto fake_prim = primitive->cast<PrimFakeQuantWithMinMaxVarsPtr>();
MS_EXCEPTION_IF_NULL(fake_prim);
auto prim_name = fake_prim->name();
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), prim_name);
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), prim_name);


+ 1
- 3
mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc View File

@@ -43,9 +43,7 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto FakeQuantWithMinMaxVarsPerChannel_prim = primitive->cast<PrimFakeQuantWithMinMaxVarsPerChannelPtr>();
MS_EXCEPTION_IF_NULL(FakeQuantWithMinMaxVarsPerChannel_prim);
auto op_name = FakeQuantWithMinMaxVarsPerChannel_prim->name();
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), op_name);
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), op_name);


+ 1
- 3
mindspore/core/ops/fft_imag.cc View File

@@ -24,9 +24,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto FftImag_prim = primitive->cast<PrimFftImagPtr>();
MS_EXCEPTION_IF_NULL(FftImag_prim);
auto prim_name = FftImag_prim->name();
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name);
in_shape.pop_back();
return std::make_shared<abstract::Shape>(in_shape);


+ 1
- 3
mindspore/core/ops/flatten.cc View File

@@ -23,9 +23,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto flatten_prim = primitive->cast<PrimFlattenPtr>();
MS_EXCEPTION_IF_NULL(flatten_prim);
auto prim_name = flatten_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto prod = 1;


+ 1
- 3
mindspore/core/ops/floor.cc View File

@@ -28,9 +28,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto floor_prim = primitive->cast<PrimFLoorPtr>();
MS_EXCEPTION_IF_NULL(floor_prim);
auto prim_name = floor_prim->name();
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}


+ 1
- 3
mindspore/core/ops/fusion/add_fusion.cc View File

@@ -39,9 +39,7 @@ void AddFusion::Init(const ActivationType activation_type) { this->set_activatio
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto add_prim = primitive->cast<PrimAddFusionPtr>();
MS_EXCEPTION_IF_NULL(add_prim);
auto op_name = add_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



+ 7
- 8
mindspore/core/ops/fusion/avg_pool_fusion.cc View File

@@ -52,22 +52,21 @@ ActivationType AvgPoolFusion::get_activation_type() const {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto pool_prim = primitive->cast<PrimAvgPoolFusionPtr>();
MS_EXCEPTION_IF_NULL(pool_prim);
auto op_name = pool_prim->name();
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
if (pool_prim->get_format() == NHWC) {
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
}
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
auto kernel_size = pool_prim->get_kernel_size();
auto pad_mode = pool_prim->get_pad_mode();
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
auto batch = in_shape[0];
auto channel = in_shape[1];
auto in_h = in_shape[2];
auto in_w = in_shape[3];

auto strides = pool_prim->get_strides();
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
auto kernel_h = kernel_size[2];
auto kernel_w = kernel_size[3];
auto stride_h = strides[2];
@@ -82,7 +81,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
out_w = ceil(in_w / stride_w);
}
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
if (pool_prim->get_format() == NHWC) {
if (format == NHWC) {
out_shape = {batch, out_h, out_w, channel};
}
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {


+ 12
- 21
mindspore/core/ops/fusion/full_connection.cc View File

@@ -21,22 +21,13 @@
namespace mindspore {
namespace ops {
void FullConnection::set_has_bias(const bool has_bias) { this->AddAttr(kHasBias, MakeValue(has_bias)); }
bool FullConnection::get_has_bias() const {
auto value_ptr = GetAttr(kHasBias);
return GetValue<bool>(value_ptr);
}
bool FullConnection::get_has_bias() const { return GetValue<bool>(GetAttr(kHasBias)); }

void FullConnection::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }
int64_t FullConnection::get_axis() const {
auto value_ptr = GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
int64_t FullConnection::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }

void FullConnection::set_use_axis(const bool use_axis) { this->AddAttr(kUseAxis, MakeValue(use_axis)); }
bool FullConnection::get_use_axis() const {
auto value_ptr = GetAttr(kUseAxis);
return GetValue<bool>(value_ptr);
}
bool FullConnection::get_use_axis() const { return GetValue<bool>(GetAttr(kUseAxis)); }

void FullConnection::set_activation_type(const ActivationType &activation_type) {
int64_t swi;
@@ -57,26 +48,26 @@ void FullConnection::Init(const bool has_bias, const int64_t axis, const bool us
AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto full_prim = primitive->cast<PrimFullConnectionPtr>();
MS_EXCEPTION_IF_NULL(full_prim);
auto prim_name = full_prim->name();
auto prim_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
MS_EXCEPTION_IF_NULL(input_args[1]);
auto input0 = input_args[0];
auto input1 = input_args[1];
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input0->BuildShape(), prim_name);
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input1->BuildShape(), prim_name);
auto prim_axis = full_prim->get_axis();
if (full_prim->get_has_bias()) {
auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
if (has_bias) {
CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 3, prim_name);
} else {
CheckAndConvertUtils::CheckInteger("input_args.size()", input_args.size(), kEqual, 2, prim_name);
}
if (full_prim->get_use_axis() && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) {
auto use_axis = GetValue<bool>(primitive->GetAttr(kUseAxis));
if (use_axis && (prim_axis < 1 || prim_axis > (int64_t)input0_shape.size())) {
MS_EXCEPTION(ValueError) << "Full Connection axis invalid";
}
int64_t new_k = 1;
if (full_prim->get_use_axis()) {
if (use_axis) {
for (size_t t = prim_axis; t < input0_shape.size(); t++) {
new_k *= input0_shape[t];
}
@@ -86,7 +77,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
} else {
new_k = input1_shape[1];
}
if (full_prim->get_has_bias()) {
if (has_bias) {
auto input2_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), prim_name);
if (input2_shape[0] != input1_shape[0]) {
@@ -94,7 +85,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
}
}
std::vector<int64_t> out_shape = {(int64_t)input0_shape.size()};
if (full_prim->get_use_axis()) {
if (use_axis) {
out_shape.resize(prim_axis + 1);
out_shape[prim_axis] = input1_shape[0];
} else {


+ 7
- 8
mindspore/core/ops/fusion/max_pool_fusion.cc View File

@@ -52,22 +52,21 @@ ActivationType MaxPoolFusion::get_activation_type() const {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto pool_prim = primitive->cast<PrimMaxPoolFusionPtr>();
MS_EXCEPTION_IF_NULL(pool_prim);
auto op_name = pool_prim->name();
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
if (pool_prim->get_format() == NHWC) {
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
}
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
auto kernel_size = pool_prim->get_kernel_size();
auto pad_mode = pool_prim->get_pad_mode();
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
auto batch = in_shape[0];
auto channel = in_shape[1];
auto in_h = in_shape[2];
auto in_w = in_shape[3];

auto strides = pool_prim->get_strides();
auto strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
auto kernel_h = kernel_size[2];
auto kernel_w = kernel_size[3];
auto stride_h = strides[2];
@@ -82,7 +81,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
out_w = ceil(in_w / stride_w);
}
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
if (pool_prim->get_format() == NHWC) {
if (format == NHWC) {
out_shape = {batch, out_h, out_w, channel};
}
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {


+ 1
- 3
mindspore/core/ops/fusion/pow_fusion.cc View File

@@ -37,9 +37,7 @@ float PowFusion::get_shift() const { return GetValue<float>(GetAttr(kShift)); }
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto pow_prim = primitive->cast<PrimPowPtr>();
MS_EXCEPTION_IF_NULL(pow_prim);
auto op_name = pow_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



+ 1
- 3
mindspore/core/ops/fusion/slice_fusion.cc View File

@@ -32,9 +32,7 @@ std::vector<int64_t> SliceFusion::get_axes() const {
AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto SliceFusion_prim = primitive->cast<PrimSliceFusionPtr>();
MS_EXCEPTION_IF_NULL(SliceFusion_prim);
auto op_name = SliceFusion_prim->name();
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
auto x_shape_len = (int64_t)x_shape.size();
auto begin_v = input_args[1]->BuildValue();


+ 1
- 3
mindspore/core/ops/gather_nd.cc View File

@@ -27,9 +27,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto gather_prim = primitive->cast<PrimGatherNd>();
MS_EXCEPTION_IF_NULL(gather_prim);
auto prim_name = gather_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);


+ 1
- 3
mindspore/core/ops/gelu.cc View File

@@ -28,9 +28,7 @@ namespace ops {
namespace {
abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto gelu_prim = primitive->cast<PrimGeLUPtr>();
MS_EXCEPTION_IF_NULL(gelu_prim);
auto prim_name = gelu_prim->name();
auto prim_name = primitive->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
return std::make_shared<abstract::Shape>(input_shape);
}


+ 0
- 2
mindspore/core/ops/grad/avg_pool_grad.cc View File

@@ -22,8 +22,6 @@ namespace ops {
AbstractBasePtr AvgPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto AvgPoolGrad_prim = primitive->cast<PrimAvgPoolGradPtr>();
MS_EXCEPTION_IF_NULL(AvgPoolGrad_prim);
MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
auto origin_input_shape = GetValue<std::vector<int64_t>>(input_args[0]->BuildValue());
auto tensor_type = input_args[1]->BuildType()->cast<TensorTypePtr>();


+ 1
- 3
mindspore/core/ops/grad/batch_norm_grad.cc View File

@@ -47,9 +47,7 @@ bool BatchNormGrad::get_is_training() const {
AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto BatchNormGrad_prim = primitive->cast<PrimBatchNormGradPtr>();
MS_EXCEPTION_IF_NULL(BatchNormGrad_prim);
auto op_name = BatchNormGrad_prim->name();
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[1]);
MS_EXCEPTION_IF_NULL(input_args[2]);
MS_EXCEPTION_IF_NULL(input_args[3]);


+ 1
- 3
mindspore/core/ops/grad/bias_add_grad.cc View File

@@ -41,9 +41,7 @@ Format BiasAddGrad::get_format() const {
AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto bias_prim = primitive->cast<PrimBiasAddGradPtr>();
MS_EXCEPTION_IF_NULL(bias_prim);
auto prim_name = bias_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("bias_grad_infer", input_args.size(), kEqual, 1, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);



+ 1
- 3
mindspore/core/ops/grad/binary_cross_entropy_grad.cc View File

@@ -26,9 +26,7 @@ namespace {
abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto binary_cross_entropy_grad_prim = primitive->cast<PrimBinaryCrossEntropyGradPtr>();
MS_EXCEPTION_IF_NULL(binary_cross_entropy_grad_prim);
auto prim_name = binary_cross_entropy_grad_prim->name();
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto weight_shape =


+ 2
- 6
mindspore/core/ops/grad/dropout_grad.cc View File

@@ -35,18 +35,14 @@ namespace {
abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto DropoutGrad_prim = primitive->cast<PrimDropoutGradPtr>();
MS_EXCEPTION_IF_NULL(DropoutGrad_prim);
auto op_name = DropoutGrad_prim->name();
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
return std::make_shared<abstract::Shape>(in_shape);
}

TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto DropoutGrad_prim = prim->cast<PrimDropoutGradPtr>();
MS_EXCEPTION_IF_NULL(DropoutGrad_prim);
auto op_name = DropoutGrad_prim->name();
auto op_name = prim->name();
auto mask_dtype = input_args[1]->BuildType();
auto dy_dtype = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("mask", mask_dtype, {kTensorType}, op_name);


+ 3
- 6
mindspore/core/ops/grad/group_conv2d_grad_input.cc View File

@@ -114,8 +114,7 @@ void GroupConv2DGradInput::set_input_shape(const std::vector<int64_t> &input_sha
}

std::vector<int64_t> GroupConv2DGradInput::get_input_shape() const {
auto value_ptr = GetAttr(kInputShape);
return GetValue<std::vector<int64_t>>(value_ptr);
return GetValue<std::vector<int64_t>>(GetAttr(kInputShape));
}

void GroupConv2DGradInput::set_format(const Format &format) {
@@ -147,14 +146,12 @@ bool GroupConv2DGradInput::get_has_bias() const {
AbstractBasePtr GroupConv2DGradInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto group_prim = primitive->cast<PrimGroupConv2DGradInputPtr>();
MS_EXCEPTION_IF_NULL(group_prim);
auto prim_name = group_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("group_conv_2D_infer", input_args.size(), kGreaterEqual, 2, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);

// Infer shape
auto shape = group_prim->get_input_shape();
auto shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kInputShape));

// Infer type
auto type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();


+ 2
- 3
mindspore/core/ops/grad/max_pool_grad.cc View File

@@ -21,9 +21,8 @@ namespace mindspore {
namespace ops {
AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto MaxPoolGrad_prim = primitive->cast<PrimMaxPoolGradPtr>();
MS_EXCEPTION_IF_NULL(MaxPoolGrad_prim);
auto op_name = MaxPoolGrad_prim->name();
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name);
auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();


+ 1
- 3
mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc View File

@@ -30,9 +30,7 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto sigmoid_prim = primitive->cast<PrimSigmoidCrossEntropyWithLogitsGradPtr>();
MS_EXCEPTION_IF_NULL(sigmoid_prim);
auto prim_name = sigmoid_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", input_args.size(), kEqual, 3,
prim_name);



+ 1
- 3
mindspore/core/ops/grad/smooth_l1_loss_grad.cc View File

@@ -36,9 +36,7 @@ float SmoothL1LossGrad::get_beta() const {
AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto smooth_prim = primitive->cast<PrimSmoothL1LossGradPtr>();
MS_EXCEPTION_IF_NULL(smooth_prim);
auto prim_name = smooth_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name);

// Infer shape


+ 1
- 3
mindspore/core/ops/hashtable_lookup.cc View File

@@ -24,12 +24,10 @@ namespace ops {
AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto HashtableLookup_prim = primitive->cast<PrimHashtableLookupPtr>();
MS_EXCEPTION_IF_NULL(HashtableLookup_prim);
for (auto input : input_args) {
MS_EXCEPTION_IF_NULL(input);
}
auto op_name = HashtableLookup_prim->name();
auto op_name = primitive->name();
std::vector<int64_t> hits_shape;
auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
hits_shape.push_back(input[0]);


+ 3
- 8
mindspore/core/ops/l2_normalize.cc View File

@@ -29,10 +29,7 @@ void L2Normalize::set_axis(const std::vector<int64_t> &axis) { AddAttr(kAxis, Ma

void L2Normalize::set_epsilon(const float epsilon) { AddAttr(kEpsilon, MakeValue(epsilon)); }

std::vector<int64_t> L2Normalize::get_axis() const {
auto value_ptr = GetAttr(kAxis);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> L2Normalize::get_axis() const { return GetValue<std::vector<int64_t>>(GetAttr(kAxis)); }

float L2Normalize::get_epsilon() const {
auto value_ptr = GetAttr(kEpsilon);
@@ -42,9 +39,7 @@ float L2Normalize::get_epsilon() const {
AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim = primitive->cast<PrimL2NormalizePtr>();
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
@@ -53,7 +48,7 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_rank = SizeToLong(x_shape.size());
auto axiss = prim->get_axis();
auto axiss = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
for (auto &axis : axiss) {
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
}


+ 1
- 3
mindspore/core/ops/less.cc View File

@@ -27,9 +27,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto less_prim = primitive->cast<PrimLessPtr>();
MS_EXCEPTION_IF_NULL(less_prim);
auto op_name = less_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



+ 1
- 3
mindspore/core/ops/less_equal.cc View File

@@ -28,9 +28,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto equal_prim = primitive->cast<PrimLessEqualPtr>();
MS_EXCEPTION_IF_NULL(equal_prim);
auto op_name = equal_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



+ 1
- 3
mindspore/core/ops/logical_and.cc View File

@@ -28,9 +28,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto logicaland_prim = primitive->cast<PrimLogicalAndPtr>();
MS_EXCEPTION_IF_NULL(logicaland_prim);
auto op_name = logicaland_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



+ 2
- 6
mindspore/core/ops/logical_not.cc View File

@@ -24,18 +24,14 @@ namespace ops {
namespace {
abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto LogicalNot_prim = primitive->cast<PrimLogicalNotPtr>();
MS_EXCEPTION_IF_NULL(LogicalNot_prim);
auto op_name = LogicalNot_prim->name();
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
return std::make_shared<abstract::Shape>(in_shape);
}

TypePtr LogicalNotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto LogicalNot_prim = prim->cast<PrimLogicalNotPtr>();
MS_EXCEPTION_IF_NULL(LogicalNot_prim);
auto op_name = LogicalNot_prim->name();
auto op_name = prim->name();
auto infer_dtype = input_args[0]->BuildType();
std::set<TypePtr> local_bool = {kBool};
return CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name);


+ 1
- 3
mindspore/core/ops/logical_or.cc View File

@@ -29,9 +29,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto logicalor_prim = primitive->cast<PrimLogicalOrPtr>();
MS_EXCEPTION_IF_NULL(logicalor_prim);
auto op_name = logicalor_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



+ 1
- 3
mindspore/core/ops/lrn.cc View File

@@ -77,9 +77,7 @@ void LRN::Init(const int64_t depth_radius, const float bias, const float alpha,
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto lrn_prim = primitive->cast<PrimLrn>();
MS_EXCEPTION_IF_NULL(lrn_prim);
auto prim_name = lrn_prim->name();
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name);
return std::make_shared<abstract::Shape>(in_shape);


+ 3
- 11
mindspore/core/ops/lsh_projection.cc View File

@@ -26,20 +26,12 @@ void LshProjection::set_type(const LshProjectionType &type) {
AddAttr(kType, MakeValue(swi));
}

LshProjectionType LshProjection::get_type() const {
auto value_ptr = GetAttr(kType);
return LshProjectionType(GetValue<int64_t>(value_ptr));
}
LshProjectionType LshProjection::get_type() const { return LshProjectionType(GetValue<int64_t>(GetAttr(kType))); }

AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto LshProjection_prim = primitive->cast<PrimLshProjectionPtr>();
MS_EXCEPTION_IF_NULL(LshProjection_prim);
// if (input_args.size() != 2 && input_args.size() != 3) {
// MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << input_args.size() << " is given.";
// }
auto op_name = LshProjection_prim->name();
auto op_name = primitive->name();
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name);
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name);
CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name);
@@ -53,7 +45,7 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
}

std::vector<int64_t> out_shape;
switch ((int64_t)LshProjection_prim->get_type()) {
switch ((int64_t)LshProjectionType(GetValue<int64_t>(primitive->GetAttr(kType)))) {
case (int64_t)LshProjectionType::SPARSE:
out_shape.push_back(input0[0]);
break;


+ 21
- 35
mindspore/core/ops/lstm.cc View File

@@ -19,28 +19,34 @@
namespace mindspore {
namespace ops {
namespace {
int64_t get_good_ld(const int64_t dim, const int64_t type_size) {
int64_t ld = ((dim + (64 / type_size) - 1) / (64 / type_size)) * (64 / type_size);
if (ld * 256 == 0) {
return ld + 64 / type_size;
}
return ld;
}

AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
// infer shape
MS_EXCEPTION_IF_NULL(primitive);
auto lstm_prim = primitive->cast<PrimLstmPtr>();
MS_EXCEPTION_IF_NULL(lstm_prim);
auto prim_name = lstm_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name);
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("h_shape", input_args[1]->BuildShape(), prim_name);
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("c_shape", input_args[2]->BuildShape(), prim_name);

int64_t input_x_size = lstm_prim->get_input_size();
int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name);
CheckAndConvertUtils::CheckInteger("x_shape[2]", x_input_shape[2], kEqual, input_x_size, prim_name);

CheckAndConvertUtils::CheckInteger("h_shape.size()", h_input_shape.size(), kEqual, 3, prim_name);
CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, lstm_prim->name());
CheckAndConvertUtils::Check("h_shape", h_input_shape, kEqual, "c_shape", c_input_shape, prim_name);

int64_t num_layers = lstm_prim->get_num_layers();
int64_t num_directions = lstm_prim->get_num_directions();
int64_t hidden_size = lstm_prim->get_hidden_size();
int64_t input_size = lstm_prim->get_input_size();
int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers));
int64_t num_directions = GetValue<int64_t>(primitive->GetAttr(kNumDirections));
int64_t hidden_size = GetValue<int64_t>(primitive->GetAttr(kHidden_size));
int64_t input_size = input_x_size;
CheckAndConvertUtils::CheckInteger("h_shape[0]", h_input_shape[0], kEqual, num_layers * num_directions, prim_name);
CheckAndConvertUtils::CheckInteger("h_shape[1]", h_input_shape[1], kEqual, x_input_shape[1], prim_name);
CheckAndConvertUtils::CheckInteger("h_shape[2]", h_input_shape[2], kEqual, hidden_size, prim_name);
@@ -48,8 +54,8 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
std::vector<int64_t> y_shape = {x_input_shape[0], x_input_shape[1], hidden_size * num_directions};

int64_t type_size = 4;
int64_t gates_ws_ld = lstm_prim->get_good_ld(hidden_size * 4, type_size);
int64_t states_ws_ld = lstm_prim->get_good_ld(std::max(hidden_size, input_size), type_size);
int64_t gates_ws_ld = get_good_ld(hidden_size * 4, type_size);
int64_t states_ws_ld = get_good_ld(std::max(hidden_size, input_size), type_size);
int64_t ws_gates_size = num_layers * num_directions * x_input_shape[0] * x_input_shape[1] * gates_ws_ld * type_size;
int64_t ws_states_size =
(num_layers + 1) * num_directions * (x_input_shape[0] + 1) * x_input_shape[1] * states_ws_ld * type_size;
@@ -99,26 +105,17 @@ void LSTM::set_input_size(const int64_t input_size) {
CheckAndConvertUtils::CheckInteger(kInput_size, input_size, kGreaterThan, 0, this->name());
AddAttr(kInput_size, MakeValue(input_size));
}
int64_t LSTM::get_input_size() const {
auto value_ptr = this->GetAttr(kInput_size);
return GetValue<int64_t>(value_ptr);
}
int64_t LSTM::get_input_size() const { return GetValue<int64_t>(GetAttr(kInput_size)); }
void LSTM::set_hidden_size(const int64_t hidden_size) {
CheckAndConvertUtils::CheckInteger(kHidden_size, hidden_size, kGreaterThan, 0, this->name());
AddAttr(kHidden_size, MakeValue(hidden_size));
}
int64_t LSTM::get_hidden_size() const {
auto value_ptr = this->GetAttr(kHidden_size);
return GetValue<int64_t>(value_ptr);
}
int64_t LSTM::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); }
void LSTM::set_num_layers(const int64_t num_layers) {
CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name());
AddAttr(kNumLayers, MakeValue(num_layers));
}
int64_t LSTM::get_num_layers() const {
auto value_ptr = this->GetAttr(kNumLayers);
return GetValue<int64_t>(value_ptr);
}
int64_t LSTM::get_num_layers() const { return GetValue<int64_t>(GetAttr(kNumLayers)); }
void LSTM::set_has_bias(const bool has_bias) { AddAttr(kHasBias, MakeValue(has_bias)); }
bool LSTM::get_has_bias() const {
auto value_ptr = this->GetAttr(kHasBias);
@@ -138,10 +135,7 @@ bool LSTM::get_bidirectional() const {
return GetValue<bool>(value_ptr);
}
void LSTM::set_num_directions(const int64_t num_directions) { AddAttr(kNumDirections, MakeValue(num_directions)); }
int64_t LSTM::get_num_directions() const {
auto value_ptr = this->GetAttr(kNumDirections);
return GetValue<int64_t>(value_ptr);
}
int64_t LSTM::get_num_directions() const { return GetValue<int64_t>(GetAttr(kNumDirections)); }
void LSTM::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); }

float LSTM::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); }
@@ -167,14 +161,6 @@ void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64
this->set_zoneout_hidden(zoneout_hidden);
}

int64_t LSTM::get_good_ld(const int64_t dim, const int64_t type_size) {
int64_t ld = ((dim + (64 / type_size) - 1) / (64 / type_size)) * (64 / type_size);
if (ld * 256 == 0) {
return ld + 64 / type_size;
}
return ld;
}

AbstractBasePtr LstmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(LstmInfer(primitive, input_args));


+ 1
- 3
mindspore/core/ops/matrix_diag.cc View File

@@ -29,9 +29,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto matrixdiag_prim = primitive->cast<PrimMatrixDiagPtr>();
MS_EXCEPTION_IF_NULL(matrixdiag_prim);
auto prim_name = matrixdiag_prim->name();
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto assist_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("assist_shape", input_args[1]->BuildShape(), prim_name);


+ 4
- 16
mindspore/core/ops/max_pool.cc View File

@@ -31,37 +31,25 @@ void MaxPool::set_pad_mode(const PadMode &pad_mode) {
this->AddAttr(kPadMode, MakeValue(swi));
}

PadMode MaxPool::get_pad_mode() const {
auto value_ptr = GetAttr(kPadMode);
return PadMode(GetValue<int64_t>(value_ptr));
}
PadMode MaxPool::get_pad_mode() const { return PadMode(GetValue<int64_t>(GetAttr(kPadMode))); }
void MaxPool::set_kernel_size(const std::vector<int64_t> &kernel_size) {
this->AddAttr(kKernelSize,
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name())));
}

std::vector<int64_t> MaxPool::get_kernel_size() const {
auto value_ptr = GetAttr(kKernelSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> MaxPool::get_kernel_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kKernelSize)); }
void MaxPool::set_strides(const std::vector<int64_t> &strides) {
this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name())));
}

std::vector<int64_t> MaxPool::get_strides() const {
auto value_ptr = GetAttr(kStrides);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> MaxPool::get_strides() const { return GetValue<std::vector<int64_t>>(GetAttr(kStrides)); }

void MaxPool::set_format(const Format &format) {
int64_t f = format;
this->AddAttr(kFormat, MakeValue(f));
}

Format MaxPool::get_format() const {
auto value_ptr = GetAttr(kFormat);
return Format(GetValue<int64_t>(value_ptr));
}
Format MaxPool::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }

void MaxPool::set_pad(const std::vector<int64_t> &pad) { this->AddAttr(kPad, MakeValue(pad)); }



+ 1
- 3
mindspore/core/ops/maximum.cc View File

@@ -25,9 +25,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto maximum_prim = primitive->cast<PrimMaximumPtr>();
MS_EXCEPTION_IF_NULL(maximum_prim);
auto op_name = maximum_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



+ 1
- 3
mindspore/core/ops/merge.cc View File

@@ -28,9 +28,7 @@ namespace ops {
AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto Merge_prim = primitive->cast<PrimMergePtr>();
MS_EXCEPTION_IF_NULL(Merge_prim);
auto op_name = Merge_prim->name();
auto op_name = primitive->name();
auto inputs_type = input_args[0]->BuildType()->cast<TuplePtr>()->elements();
auto inputs_shape = input_args[0]->BuildShape()->cast<abstract::TupleShapePtr>()->shape();
std::map<std::string, TypePtr> args;


+ 4
- 8
mindspore/core/ops/mfcc.cc View File

@@ -24,16 +24,15 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto mfcc_prim = primitive->cast<PrimMfccPtr>();
MS_EXCEPTION_IF_NULL(mfcc_prim);
auto prim_name = mfcc_prim->name();
auto prim_name = primitive->name();
auto first_input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name);
auto second_input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("second_input_shape", input_args[1]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name);
CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name);
std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1], mfcc_prim->get_dct_coeff_num()};
std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1],
GetValue<int64_t>(primitive->GetAttr(kDctCoeffNum))};
return std::make_shared<abstract::Shape>(out_shape);
}

@@ -83,10 +82,7 @@ int64_t Mfcc::get_filter_bank_channel_num() const {

void Mfcc::set_dct_coeff_num(const int64_t dct_coeff_num) { this->AddAttr(kDctCoeffNum, MakeValue(dct_coeff_num)); }

int64_t Mfcc::get_dct_coeff_num() const {
auto value_ptr = this->GetAttr(kDctCoeffNum);
return GetValue<int64_t>(value_ptr);
}
int64_t Mfcc::get_dct_coeff_num() const { return GetValue<int64_t>(GetAttr(kDctCoeffNum)); }

AbstractBasePtr MfccInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {


+ 0
- 3
mindspore/core/ops/non_max_suppression.cc View File

@@ -31,9 +31,6 @@ void NonMaxSuppression::Init(const int64_t center_point_box) { this->set_center_

AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto non_max_suppression_prim = primitive->cast<PrimNonMaxSuppressionPtr>();
MS_EXCEPTION_IF_NULL(non_max_suppression_prim);
MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime.";
return std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>{});
}


+ 4
- 11
mindspore/core/ops/one_hot.cc View File

@@ -25,17 +25,12 @@ namespace ops {
void OneHot::Init(const int64_t axis) { this->set_axis(axis); }
void OneHot::set_axis(const int64_t axis) { this->AddAttr(kAxis, MakeValue(axis)); }

int64_t OneHot::get_axis() const {
auto value_ptr = this->GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
int64_t OneHot::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
namespace {
abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto OneHot_prim = primitive->cast<PrimOneHotPtr>();
MS_EXCEPTION_IF_NULL(OneHot_prim);
auto op_name = OneHot_prim->name();
int64_t axis = OneHot_prim->get_axis();
auto op_name = primitive->name();
int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name);
auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue());
@@ -50,9 +45,7 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve

TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto OneHot_prim = prim->cast<PrimOneHotPtr>();
MS_EXCEPTION_IF_NULL(OneHot_prim);
auto op_name = OneHot_prim->name();
auto op_name = prim->name();
CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32}, op_name);
CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name);
std::map<std::string, TypePtr> args = {{"on_value", input_args[2]->BuildType()},


+ 1
- 3
mindspore/core/ops/ones_like.cc View File

@@ -28,9 +28,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto OnesLike_prim = primitive->cast<PrimOnesLikePtr>();
MS_EXCEPTION_IF_NULL(OnesLike_prim);
auto prim_name = OnesLike_prim->name();
auto prim_name = primitive->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
return std::make_shared<abstract::Shape>(input_shape);


+ 3
- 8
mindspore/core/ops/pack.cc View File

@@ -50,23 +50,18 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve

void Pack::set_axis(const int64_t &axis) { AddAttr(kAxis, MakeValue(axis)); }

int64_t Pack::get_axis() const {
auto value_ptr = this->GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
int64_t Pack::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }

void Pack::Init(const int64_t &axis) { this->set_axis(axis); }

AbstractBasePtr PackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto pack_prim = primitive->cast<PrimPackPtr>();
MS_EXCEPTION_IF_NULL(pack_prim);
auto prim_name = pack_prim->name();
auto prim_name = primitive->name();

auto x_shapes = input_args[0]->BuildShape()->cast<abstract::TupleShapePtr>()->shape();
auto x_types = input_args[0]->BuildType()->cast<TuplePtr>()->elements();
auto all_shape = _get_pack_shape(x_shapes, x_types, pack_prim->get_axis(), prim_name);
auto all_shape = _get_pack_shape(x_shapes, x_types, GetValue<int64_t>(primitive->GetAttr(kAxis)), prim_name);
auto tensor_type = x_types[0]->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();


+ 3
- 6
mindspore/core/ops/pad.cc View File

@@ -23,10 +23,8 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto pad_prim = primitive->cast<PrimPadPtr>();
MS_EXCEPTION_IF_NULL(pad_prim);
auto prim_name = pad_prim->name();
auto paddings_attr = pad_prim->get_paddings();
auto prim_name = primitive->name();
auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Pad");
CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()),
prim_name);
@@ -59,8 +57,7 @@ void Pad::set_paddings(const std::vector<std::vector<int64_t>> &paddings) {
this->AddAttr(kPaddings, MakeValue(paddings));
}
std::vector<std::vector<int64_t>> Pad::get_paddings() const {
auto value_ptr = GetAttr(kPaddings);
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
return GetValue<std::vector<std::vector<int64_t>>>(GetAttr(kPaddings));
}
AbstractBasePtr PadInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {


+ 1
- 3
mindspore/core/ops/pow.cc View File

@@ -24,9 +24,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto pow_prim = primitive->cast<PrimPowPtr>();
MS_EXCEPTION_IF_NULL(pow_prim);
auto op_name = pow_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



+ 8
- 19
mindspore/core/ops/prior_box.cc View File

@@ -24,10 +24,7 @@ namespace mindspore {
namespace ops {
void PriorBox::set_min_sizes(const std::vector<int64_t> &min_sizes) { this->AddAttr(kMinSizes, MakeValue(min_sizes)); }

std::vector<int64_t> PriorBox::get_min_sizes() const {
auto value_ptr = GetAttr(kMinSizes);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> PriorBox::get_min_sizes() const { return GetValue<std::vector<int64_t>>(GetAttr(kMinSizes)); }

void PriorBox::set_max_sizes(const std::vector<int64_t> &max_sizes) { this->AddAttr(kMaxSizes, MakeValue(max_sizes)); }

@@ -40,10 +37,7 @@ void PriorBox::set_aspect_ratios(const std::vector<float> &aspect_ratios) {
this->AddAttr(kAspectRatios, MakeValue(aspect_ratios));
}

std::vector<float> PriorBox::get_aspect_ratios() const {
auto value_ptr = GetAttr(kAspectRatios);
return GetValue<std::vector<float>>(value_ptr);
}
std::vector<float> PriorBox::get_aspect_ratios() const { return GetValue<std::vector<float>>(GetAttr(kAspectRatios)); }

void PriorBox::set_variances(const std::vector<float> &variances) { this->AddAttr(kVariances, MakeValue(variances)); }

@@ -89,10 +83,7 @@ bool PriorBox::get_clip() const {

void PriorBox::set_flip(const bool flip) { this->AddAttr(kFlip, MakeValue(flip)); }

bool PriorBox::get_flip() const {
auto value_ptr = GetAttr(kFlip);
return GetValue<bool>(value_ptr);
}
bool PriorBox::get_flip() const { return GetValue<bool>(GetAttr(kFlip)); }

void PriorBox::set_offset(const float offset) { this->AddAttr(kOffset, MakeValue(offset)); }

@@ -121,25 +112,23 @@ void PriorBox::Init(const std::vector<int64_t> &min_sizes, const std::vector<int
AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto PriorBox_prim = primitive->cast<PrimPriorBoxPtr>();
MS_EXCEPTION_IF_NULL(PriorBox_prim);
auto op_name = PriorBox_prim->name();
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
std::vector<float> different_aspect_ratios{1.0f};
auto aspect_ratios = PriorBox_prim->get_aspect_ratios();
auto aspect_ratios = GetValue<std::vector<float>>(primitive->GetAttr(kAspectRatios));
for (int64_t i = 0; i < (int64_t)aspect_ratios.size(); i++) {
float ratio = aspect_ratios[i];
bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(),
[&](float v) { return abs(ratio - v) < 1e-6; });
if (!exist) {
different_aspect_ratios.emplace_back(ratio);
if (PriorBox_prim->get_flip()) {
if (GetValue<bool>(primitive->GetAttr(kFlip))) {
different_aspect_ratios.emplace_back(1.0f / ratio);
}
}
}
int64_t num_priors_box =
PriorBox_prim->get_min_sizes().size() * different_aspect_ratios.size() + PriorBox_prim->get_max_sizes().size();
auto min_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kMinSizes));
int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size();
auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
int64_t h = input[0] * input[1] * num_priors_box * 4;
std::vector<int64_t> output_shape{1, h, 1, 2};


+ 5
- 10
mindspore/core/ops/quant_dtype_cast.cc View File

@@ -24,10 +24,7 @@ int64_t QuantDTypeCast::get_src_t() const {
return GetValue<int64_t>(value_ptr);
}
void QuantDTypeCast::set_dst_t(const int64_t dst_t) { AddAttr(kDstT, MakeValue(dst_t)); }
int64_t QuantDTypeCast::get_dst_t() const {
auto value_ptr = this->GetAttr(kDstT);
return GetValue<int64_t>(value_ptr);
}
int64_t QuantDTypeCast::get_dst_t() const { return GetValue<int64_t>(GetAttr(kDstT)); }
void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) {
this->set_src_t(src_t);
this->set_dst_t(dst_t);
@@ -35,16 +32,14 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) {
AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto QuantDTypeCast_prim = primitive->cast<PrimQuantDTypeCastPtr>();
MS_EXCEPTION_IF_NULL(QuantDTypeCast_prim);
auto op_name = QuantDTypeCast_prim->name();
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]);
auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_type);
MS_ASSERT(input_type->element() == TypeIdToType(TypeId(QuantDTypeCast_prim->get_dst_t())));
auto dst_type = GetValue<int64_t>(primitive->GetAttr(kDstT));
MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type)));
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(QuantDTypeCast_prim->get_dst_t())),
input_shape);
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(dst_type)), input_shape);
}
REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast);
} // namespace ops


+ 5
- 11
mindspore/core/ops/range.cc View File

@@ -34,10 +34,7 @@ int64_t Range::get_d_type() const {

void Range::set_start(const int64_t start) { this->AddAttr(kStart, MakeValue(start)); }

int64_t Range::get_start() const {
auto value_ptr = GetAttr(kStart);
return GetValue<int64_t>(value_ptr);
}
int64_t Range::get_start() const { return GetValue<int64_t>(GetAttr(kStart)); }

void Range::set_limit(const int64_t limit) { this->AddAttr(kLimit, MakeValue(limit)); }

@@ -63,10 +60,7 @@ void Range::Init(const int64_t d_type, const int64_t start, const int64_t limit,
AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim = primitive->cast<PrimRangePtr>();
MS_EXCEPTION_IF_NULL(prim);
int64_t shape_size = 0;
TypeId dtype;
if (input_args.size() == 3) {
MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
MS_EXCEPTION_IF_NULL(input_args[1]->BuildValue());
@@ -74,7 +68,7 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
auto start_tensor = input_args[0]->BuildValue()->cast<tensor::TensorPtr>();
auto limit_tensor = input_args[1]->BuildValue()->cast<tensor::TensorPtr>();
auto delta_tensor = input_args[2]->BuildValue()->cast<tensor::TensorPtr>();
dtype = static_cast<TypeId>(start_tensor->data_type_c());
auto dtype = start_tensor->data_type();
switch (dtype) {
case kNumberTypeInt:
case kNumberTypeInt32: {
@@ -97,9 +91,9 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
}
}
} else {
int64_t start = prim->get_start();
int64_t limit = prim->get_limit();
int64_t delta = prim->get_delta();
int64_t start = GetValue<int64_t>(primitive->GetAttr(kStart));
int64_t limit = GetValue<int64_t>(primitive->GetAttr(kLimit));
int64_t delta = GetValue<int64_t>(primitive->GetAttr(kDelta));
shape_size =
std::max(static_cast<int64_t>(std::ceil(LongToDouble(limit - start) / delta)), static_cast<int64_t>(0));
}


+ 1
- 3
mindspore/core/ops/rank.cc View File

@@ -21,9 +21,7 @@ namespace ops {
namespace {
TypePtr RankInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto Rank_prim = prim->cast<PrimRankPtr>();
MS_EXCEPTION_IF_NULL(Rank_prim);
auto op_name = Rank_prim->name();
auto op_name = prim->name();
auto infer_dtype = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, {kTensorType}, op_name);
return kTypeNone;


+ 1
- 3
mindspore/core/ops/reciprocal.cc View File

@@ -29,9 +29,7 @@ namespace ops {
AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto reciprocal_prim = primitive->cast<PrimReciprocalPtr>();
MS_EXCEPTION_IF_NULL(reciprocal_prim);
auto prim_name = reciprocal_prim->name();
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}


+ 3
- 8
mindspore/core/ops/reduce.cc View File

@@ -70,13 +70,11 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto axis_value = input_args[1]->BuildValue();

MS_EXCEPTION_IF_NULL(primitive);
auto reduce_prim = primitive->cast<PrimReducePtr>();
MS_EXCEPTION_IF_NULL(reduce_prim);
auto prim_name = reduce_prim->name();
auto prim_name = primitive->name();
auto input_x_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_x_shape", input_args[0]->BuildShape(), prim_name);

auto keep_dims = reduce_prim->get_keep_dims();
auto keep_dims = GetValue<bool>(primitive->GetAttr(kKeepDims));
auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name);

return std::make_shared<abstract::Shape>(out_shape);
@@ -93,10 +91,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &

void Reduce::set_keep_dims(const bool keep_dims) { this->AddAttr(kKeepDims, MakeValue(keep_dims)); }

bool Reduce::get_keep_dims() const {
auto value_ptr = GetAttr(kKeepDims);
return GetValue<bool>(value_ptr);
}
bool Reduce::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); }

void Reduce::Init(const bool keep_dims) { this->set_keep_dims(keep_dims); }



+ 3
- 8
mindspore/core/ops/resize_bilinear.cc View File

@@ -27,10 +27,7 @@ namespace mindspore {
namespace ops {
void ResizeBilinear::set_size(const std::vector<int64_t> &size) { this->AddAttr(kSize, MakeValue(size)); }

std::vector<int64_t> ResizeBilinear::get_size() const {
auto value_ptr = GetAttr(kSize);
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::vector<int64_t> ResizeBilinear::get_size() const { return GetValue<std::vector<int64_t>>(GetAttr(kSize)); }

void ResizeBilinear::set_align_corners(const bool align_corners) {
this->AddAttr(kAlignCorners, MakeValue(align_corners));
@@ -48,9 +45,7 @@ void ResizeBilinear::Init(const std::vector<int64_t> &size, const bool align_cor
AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto resize_prim = primitive->cast<PrimResizeBilinearPtr>();
MS_EXCEPTION_IF_NULL(resize_prim);
auto prim_name = resize_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name);

// Infer shape
@@ -58,7 +53,7 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name);
std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]};
auto size = resize_prim->get_size();
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize));
out_shape.insert(out_shape.end(), size.begin(), size.end());

// Infer type


+ 4
- 9
mindspore/core/ops/reverse_sequence.cc View File

@@ -30,10 +30,7 @@ void ReverseSequence::Init(const int64_t seq_dim, const int64_t batch_dim) {
void ReverseSequence::set_seq_dim(const int64_t seq_dim) { this->AddAttr(kSeqDim, MakeValue(seq_dim)); }
void ReverseSequence::set_batch_dim(const int64_t batch_dim) { this->AddAttr(kBatchDim, MakeValue(batch_dim)); }

int64_t ReverseSequence::get_seq_dim() const {
auto value_ptr = this->GetAttr(kSeqDim);
return GetValue<int64_t>(value_ptr);
}
int64_t ReverseSequence::get_seq_dim() const { return GetValue<int64_t>(GetAttr(kSeqDim)); }
int64_t ReverseSequence::get_batch_dim() const {
auto value_ptr = this->GetAttr(kBatchDim);
return GetValue<int64_t>(value_ptr);
@@ -41,9 +38,7 @@ int64_t ReverseSequence::get_batch_dim() const {
AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto reverse_prim = primitive->cast<PrimReverseSequence>();
MS_EXCEPTION_IF_NULL(reverse_prim);
auto prim_name = reverse_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
@@ -53,8 +48,8 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto seq_lengths =
CheckAndConvertUtils::ConvertShapePtrToShape("seq_lengths", input_args[1]->BuildShape(), prim_name);
auto seq_dim = reverse_prim->get_seq_dim();
auto batch_dim = reverse_prim->get_batch_dim();
auto seq_dim = GetValue<int64_t>(primitive->GetAttr(kSeqDim));
auto batch_dim = GetValue<int64_t>(primitive->GetAttr(kBatchDim));
CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name);
CheckAndConvertUtils::CheckInteger("batch_dim", batch_dim, kLessEqual, input_shape.size(), prim_name);
CheckAndConvertUtils::CheckInteger("batch_dim", batch_dim, kNotEqual, seq_dim, prim_name);


+ 1
- 3
mindspore/core/ops/reverse_v2.cc View File

@@ -24,9 +24,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto reverseV2_prim = primitive->cast<PrimReverseV2Ptr>();
MS_EXCEPTION_IF_NULL(reverseV2_prim);
auto prim_name = reverseV2_prim->name();
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
return std::make_shared<abstract::Shape>(x_shape);
}


+ 3
- 8
mindspore/core/ops/rfft.cc View File

@@ -24,13 +24,11 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto rfft_prim = primitive->cast<PrimRfftPtr>();
MS_EXCEPTION_IF_NULL(rfft_prim);
auto prim_name = rfft_prim->name();
auto prim_name = primitive->name();
auto first_input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name);
auto out_shape = first_input_shape;
out_shape[out_shape.size() - 1] = rfft_prim->get_fft_length() / 2 + 1;
out_shape[out_shape.size() - 1] = GetValue<int64_t>(primitive->GetAttr(kFftLength)) / 2 + 1;
out_shape.push_back(2);
return std::make_shared<abstract::Shape>(out_shape);
}
@@ -47,10 +45,7 @@ void Rfft::Init(const int64_t fft_length) { this->set_fft_length(fft_length); }

void Rfft::set_fft_length(const int64_t fft_length) { this->AddAttr(kFftLength, MakeValue(fft_length)); }

int64_t Rfft::get_fft_length() const {
auto value_ptr = this->GetAttr(kFftLength);
return GetValue<int64_t>(value_ptr);
}
int64_t Rfft::get_fft_length() const { return GetValue<int64_t>(GetAttr(kFftLength)); }

AbstractBasePtr RfftInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {


+ 4
- 9
mindspore/core/ops/roi_pooling.cc View File

@@ -27,10 +27,7 @@ namespace mindspore {
namespace ops {
void ROIPooling::set_pooled_h(const int64_t pooled_h) { this->AddAttr(kPooledH, MakeValue(pooled_h)); }

int64_t ROIPooling::get_pooled_h() const {
auto value_ptr = GetAttr(kPooledH);
return GetValue<int64_t>(value_ptr);
}
int64_t ROIPooling::get_pooled_h() const { return GetValue<int64_t>(GetAttr(kPooledH)); }

void ROIPooling::set_pooled_w(const int64_t pooled_w) { this->AddAttr(kPooledW, MakeValue(pooled_w)); }

@@ -54,9 +51,7 @@ void ROIPooling::Init(const int64_t pooled_h, const int64_t pooled_w, const floa
AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto roi_prim = primitive->cast<PrimROIPoolingPtr>();
MS_EXCEPTION_IF_NULL(roi_prim);
auto prim_name = roi_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("roi_pooling_infer", input_args.size(), kEqual, 2, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
MS_EXCEPTION_IF_NULL(input_args[1]);
@@ -65,8 +60,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi
auto output_data_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();

// Infer shape
auto new_h = roi_prim->get_pooled_h();
auto new_w = roi_prim->get_pooled_w();
auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH));
auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW));
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShape("roi_shape", input_args[1]->BuildShape(), prim_name);


+ 1
- 3
mindspore/core/ops/rsqrt.cc View File

@@ -29,9 +29,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto rsqrt_prim = primitive->cast<PrimRsqrtPtr>();
MS_EXCEPTION_IF_NULL(rsqrt_prim);
auto prim_name = rsqrt_prim->name();
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->GetShapeTrack(), prim_name);
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name);
return std::make_shared<abstract::Shape>(in_shape);


+ 1
- 3
mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc View File

@@ -29,9 +29,7 @@ namespace ops {
AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto sigmoid_prim = primitive->cast<PrimSigmoidCrossEntropyWithLogitsPtr>();
MS_EXCEPTION_IF_NULL(sigmoid_prim);
auto prim_name = sigmoid_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer", input_args.size(), kEqual, 2,
prim_name);



+ 1
- 3
mindspore/core/ops/skip_gram.cc View File

@@ -23,9 +23,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto SkipGram_prim = primitive->cast<PrimSkipGramPtr>();
MS_EXCEPTION_IF_NULL(SkipGram_prim);
auto prim_name = SkipGram_prim->name();
auto prim_name = primitive->name();
if (input_args.size() != 1) {
MS_LOG(ERROR) << "Skip Gram should have one input";
}


+ 1
- 3
mindspore/core/ops/smooth_l1_loss.cc View File

@@ -36,9 +36,7 @@ float SmoothL1Loss::get_beta() const {
AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto smooth_prim = primitive->cast<PrimSmoothL1LossPtr>();
MS_EXCEPTION_IF_NULL(smooth_prim);
auto prim_name = smooth_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name);

// Infer shape


+ 1
- 3
mindspore/core/ops/softmax_cross_entropy_with_logits.cc View File

@@ -29,9 +29,7 @@ namespace ops {
AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto softmax_prim = primitive->cast<PrimSoftmaxCrossEntropyWithLogitsPtr>();
MS_EXCEPTION_IF_NULL(softmax_prim);
auto prim_name = softmax_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", input_args.size(), kEqual, 2,
prim_name);



+ 4
- 7
mindspore/core/ops/space_to_batch.cc View File

@@ -28,15 +28,13 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto spacetobatch_prim = primitive->cast<PrimSpaceToBatchPtr>();
MS_EXCEPTION_IF_NULL(spacetobatch_prim);
auto prim_name = spacetobatch_prim->name();
auto prim_name = primitive->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name);
std::vector<int64_t> output_shape(input_shape.size());
auto block_shape_vector = spacetobatch_prim->get_block_size();
auto paddings = spacetobatch_prim->get_paddings();
auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
auto paddings = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
for (size_t i = 0; i < 2; i++) {
auto padded = output_shape[i + 2] + paddings[i][0] + paddings[i][1];
CheckAndConvertUtils::CheckInteger("padded shape", padded % block_shape_vector.size(), kEqual, 0, prim_name);
@@ -77,8 +75,7 @@ void SpaceToBatch::set_block_size(const std::vector<int64_t> block_size) {
}

std::vector<int64_t> SpaceToBatch::get_block_size() const {
auto value_ptr = GetAttr(kBlockSize);
return GetValue<std::vector<int64_t>>(value_ptr);
return GetValue<std::vector<int64_t>>(GetAttr(kBlockSize));
}

void SpaceToBatch::Init(const std::vector<int64_t> block_size, const std::vector<std::vector<int64_t>> &paddings) {


+ 4
- 7
mindspore/core/ops/space_to_batch_nd.cc View File

@@ -28,16 +28,14 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto space_prim = primitive->cast<PrimSpaceToBatchNDPtr>();
MS_EXCEPTION_IF_NULL(space_prim);
auto prim_name = space_prim->name();
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
auto out_shape = x_shape;
int64_t block_shape_prod = 1;
const int64_t offset = 2;
auto block_shape = space_prim->get_block_shape();
auto padding = space_prim->get_paddings();
auto block_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockShape));
auto padding = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
int64_t size = block_shape.size();
for (int64_t i = 0; i < size; i++) {
int64_t padded = out_shape[i + offset] + padding[i][0] + padding[i][1];
@@ -87,8 +85,7 @@ void SpaceToBatchND::set_block_shape(std::vector<int64_t> block_shape) {
}

std::vector<int64_t> SpaceToBatchND::get_block_shape() const {
auto value_ptr = GetAttr(kBlockShape);
return GetValue<std::vector<int64_t>>(value_ptr);
return GetValue<std::vector<int64_t>>(GetAttr(kBlockShape));
}

void SpaceToBatchND::Init(std::vector<int64_t> block_shape, std::vector<std::vector<int64_t>> paddings) {


+ 3
- 8
mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc View File

@@ -31,18 +31,13 @@ void SparseSoftmaxCrossEntropyWithLogits::set_is_grad(const bool is_grad) {
this->AddAttr(kIsGrad, MakeValue(is_grad));
}

bool SparseSoftmaxCrossEntropyWithLogits::get_is_grad() const {
auto value_ptr = GetAttr(kIsGrad);
return GetValue<bool>(value_ptr);
}
bool SparseSoftmaxCrossEntropyWithLogits::get_is_grad() const { return GetValue<bool>(GetAttr(kIsGrad)); }

AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto sparse_softmax_cross_entropy_prim = primitive->cast<PrimSparseSoftmaxCrossEntropyWithLogitsPtr>();
MS_EXCEPTION_IF_NULL(sparse_softmax_cross_entropy_prim);
auto prim_name = sparse_softmax_cross_entropy_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
@@ -51,7 +46,7 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
std::vector<int64_t> output_shape;
if (sparse_softmax_cross_entropy_prim->get_is_grad() != 0) {
if (GetValue<bool>(primitive->GetAttr(kIsGrad)) != 0) {
output_shape = input_shape;
} else {
output_shape.push_back(1);


+ 1
- 3
mindspore/core/ops/sparse_to_dense.cc View File

@@ -27,9 +27,7 @@ namespace ops {
AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto spasetodense_prim = primitive->cast<PrimSparseToDensePtr>();
MS_EXCEPTION_IF_NULL(spasetodense_prim);
auto prim_name = spasetodense_prim->name();
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 3, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);


+ 1
- 3
mindspore/core/ops/squared_difference.cc View File

@@ -27,9 +27,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto squared_prim = primitive->cast<PrimSquaredDifferencePtr>();
MS_EXCEPTION_IF_NULL(squared_prim);
auto op_name = squared_prim->name();
auto op_name = primitive->name();
return BroadCastInferShape(op_name, input_args);
}



Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save