Browse Source

remove ConvertShapePtrToShape function, use ConvertShapePtrToShapeMap instead

pull/15218/head
simson 4 years ago
parent
commit
894622865a
100 changed files with 154 additions and 223 deletions
  1. +1
    -2
      mindspore/core/ops/abs.cc
  2. +4
    -5
      mindspore/core/ops/adam.cc
  3. +2
    -4
      mindspore/core/ops/addn.cc
  4. +1
    -1
      mindspore/core/ops/apply_momentum.cc
  5. +1
    -1
      mindspore/core/ops/arg_max.cc
  6. +1
    -1
      mindspore/core/ops/arg_min.cc
  7. +1
    -1
      mindspore/core/ops/asin.cc
  8. +1
    -2
      mindspore/core/ops/assert.cc
  9. +1
    -3
      mindspore/core/ops/assign_add.cc
  10. +1
    -1
      mindspore/core/ops/atan.cc
  11. +1
    -3
      mindspore/core/ops/audio_spectrogram.cc
  12. +1
    -1
      mindspore/core/ops/avg_pool.cc
  13. +6
    -7
      mindspore/core/ops/batch_norm.cc
  14. +4
    -6
      mindspore/core/ops/batch_norm_fold.cc
  15. +1
    -1
      mindspore/core/ops/batch_to_space.cc
  16. +1
    -1
      mindspore/core/ops/batch_to_space_nd.cc
  17. +2
    -2
      mindspore/core/ops/bias_add.cc
  18. +3
    -4
      mindspore/core/ops/binary_cross_entropy.cc
  19. +1
    -1
      mindspore/core/ops/broadcast.cc
  20. +1
    -1
      mindspore/core/ops/broadcast_to.cc
  21. +1
    -1
      mindspore/core/ops/ceil.cc
  22. +2
    -4
      mindspore/core/ops/concat.cc
  23. +1
    -2
      mindspore/core/ops/constant_of_shape.cc
  24. +2
    -2
      mindspore/core/ops/conv2d.cc
  25. +1
    -2
      mindspore/core/ops/conv2d_transpose.cc
  26. +1
    -2
      mindspore/core/ops/cos.cc
  27. +1
    -1
      mindspore/core/ops/crop.cc
  28. +1
    -5
      mindspore/core/ops/custom_extract_features.cc
  29. +1
    -1
      mindspore/core/ops/depth_to_space.cc
  30. +2
    -2
      mindspore/core/ops/depthwise_conv2d.cc
  31. +3
    -3
      mindspore/core/ops/detection_post_process.cc
  32. +1
    -1
      mindspore/core/ops/dropout.cc
  33. +1
    -2
      mindspore/core/ops/elu.cc
  34. +1
    -1
      mindspore/core/ops/expand_dims.cc
  35. +3
    -3
      mindspore/core/ops/fake_quant_with_min_max_vars.cc
  36. +3
    -3
      mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
  37. +1
    -2
      mindspore/core/ops/fft_imag.cc
  38. +1
    -1
      mindspore/core/ops/fft_real.cc
  39. +1
    -1
      mindspore/core/ops/flatten.cc
  40. +1
    -2
      mindspore/core/ops/floor.cc
  41. +1
    -1
      mindspore/core/ops/fusion/avg_pool_fusion.cc
  42. +3
    -4
      mindspore/core/ops/fusion/full_connection.cc
  43. +1
    -1
      mindspore/core/ops/fusion/max_pool_fusion.cc
  44. +1
    -1
      mindspore/core/ops/fusion/slice_fusion.cc
  45. +2
    -2
      mindspore/core/ops/gather_d.cc
  46. +2
    -3
      mindspore/core/ops/gather_nd.cc
  47. +1
    -2
      mindspore/core/ops/gelu.cc
  48. +2
    -4
      mindspore/core/ops/grad/batch_norm_grad.cc
  49. +1
    -1
      mindspore/core/ops/grad/bias_add_grad.cc
  50. +3
    -4
      mindspore/core/ops/grad/binary_cross_entropy_grad.cc
  51. +1
    -2
      mindspore/core/ops/grad/dropout_grad.cc
  52. +1
    -3
      mindspore/core/ops/grad/max_pool_grad.cc
  53. +3
    -3
      mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc
  54. +3
    -3
      mindspore/core/ops/grad/smooth_l1_loss_grad.cc
  55. +1
    -2
      mindspore/core/ops/hashtable_lookup.cc
  56. +1
    -1
      mindspore/core/ops/l2_normalize.cc
  57. +1
    -1
      mindspore/core/ops/log.cc
  58. +1
    -2
      mindspore/core/ops/logical_not.cc
  59. +1
    -1
      mindspore/core/ops/lrn.cc
  60. +3
    -3
      mindspore/core/ops/lsh_projection.cc
  61. +3
    -3
      mindspore/core/ops/lstm.cc
  62. +2
    -2
      mindspore/core/ops/mat_mul.cc
  63. +2
    -3
      mindspore/core/ops/matrix_diag.cc
  64. +1
    -1
      mindspore/core/ops/max_pool.cc
  65. +2
    -4
      mindspore/core/ops/mfcc.cc
  66. +1
    -1
      mindspore/core/ops/one_hot.cc
  67. +1
    -3
      mindspore/core/ops/ones_like.cc
  68. +2
    -2
      mindspore/core/ops/op_utils.cc
  69. +2
    -2
      mindspore/core/ops/pack.cc
  70. +1
    -1
      mindspore/core/ops/pad.cc
  71. +2
    -2
      mindspore/core/ops/prelu.cc
  72. +1
    -2
      mindspore/core/ops/prior_box.cc
  73. +1
    -2
      mindspore/core/ops/quant_dtype_cast.cc
  74. +1
    -2
      mindspore/core/ops/reciprocal.cc
  75. +1
    -2
      mindspore/core/ops/reduce.cc
  76. +1
    -2
      mindspore/core/ops/resize_bilinear.cc
  77. +2
    -4
      mindspore/core/ops/reverse_sequence.cc
  78. +1
    -2
      mindspore/core/ops/reverse_v2.cc
  79. +1
    -3
      mindspore/core/ops/rfft.cc
  80. +2
    -3
      mindspore/core/ops/roi_pooling.cc
  81. +1
    -1
      mindspore/core/ops/round.cc
  82. +1
    -1
      mindspore/core/ops/rsqrt.cc
  83. +1
    -1
      mindspore/core/ops/scalar_summary.cc
  84. +2
    -4
      mindspore/core/ops/scatter_nd.cc
  85. +2
    -2
      mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc
  86. +1
    -1
      mindspore/core/ops/sin.cc
  87. +1
    -2
      mindspore/core/ops/skip_gram.cc
  88. +2
    -2
      mindspore/core/ops/smooth_l1_loss.cc
  89. +2
    -4
      mindspore/core/ops/softmax_cross_entropy_with_logits.cc
  90. +1
    -2
      mindspore/core/ops/space_to_batch.cc
  91. +1
    -1
      mindspore/core/ops/space_to_batch_nd.cc
  92. +1
    -2
      mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc
  93. +1
    -2
      mindspore/core/ops/sparse_to_dense.cc
  94. +1
    -1
      mindspore/core/ops/squeeze.cc
  95. +2
    -5
      mindspore/core/ops/stack.cc
  96. +1
    -1
      mindspore/core/ops/strided_slice.cc
  97. +1
    -1
      mindspore/core/ops/tan.cc
  98. +2
    -5
      mindspore/core/ops/tensor_list_from_tensor.cc
  99. +2
    -5
      mindspore/core/ops/tensor_list_stack.cc
  100. +1
    -1
      mindspore/core/ops/tensor_summary.cc

+ 1
- 2
mindspore/core/ops/abs.cc View File

@@ -30,11 +30,10 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }




+ 4
- 5
mindspore/core/ops/adam.cc View File

@@ -26,11 +26,10 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
auto prim_name = primitive->name(); auto prim_name = primitive->name();


// infer shape // infer shape
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShape("var_shape", input_args[0]->GetShapeTrack(), prim_name);
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShape("m_shape", input_args[1]->GetShapeTrack(), prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[2]->GetShapeTrack(), prim_name);
auto grad_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("grad_shape", input_args[9]->GetShapeTrack(), prim_name);
auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[9]->GetShapeTrack())[kShape];
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name); CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name); CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name); CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);


+ 2
- 4
mindspore/core/ops/addn.cc View File

@@ -38,15 +38,13 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(element0); MS_EXCEPTION_IF_NULL(element0);
auto element0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];


std::map<std::string, TypePtr> types; std::map<std::string, TypePtr> types;
types.emplace("element0", element0->BuildType()); types.emplace("element0", element0->BuildType());
for (size_t i = 1; i < elements.size(); ++i) { for (size_t i = 1; i < elements.size(); ++i) {
std::string elementi = "element" + std::to_string(i); std::string elementi = "element" + std::to_string(i);
auto elementi_shape =
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
prim_name); prim_name);
for (size_t j = 0; j < element0_shape.size(); ++j) { for (size_t j = 0; j < element0_shape.size(); ++j) {


+ 1
- 1
mindspore/core/ops/apply_momentum.cc View File

@@ -60,7 +60,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name);


// Infer shape // Infer shape
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[0]->BuildShape(), prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];


// Infer type // Infer type
auto v_tensor_type = input_args[0]->BuildType(); auto v_tensor_type = input_args[0]->BuildType();


+ 1
- 1
mindspore/core/ops/arg_max.cc View File

@@ -23,7 +23,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size()); auto x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); CheckAndConvertUtils::CheckInRange<int64_t>("argmax axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
axis = axis < 0 ? axis + x_rank : axis; axis = axis < 0 ? axis + x_rank : axis;


+ 1
- 1
mindspore/core/ops/arg_min.cc View File

@@ -42,7 +42,7 @@ AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const Primitive


// Infer shape // Infer shape
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size()); auto x_rank = SizeToLong(x_shape.size());
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name); CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeLeft, {-x_rank, x_rank}, prim_name);
if (axis < 0) { if (axis < 0) {


+ 1
- 1
mindspore/core/ops/asin.cc View File

@@ -29,7 +29,7 @@ AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("Asin_infer", input_args.size(), kEqual, 1, prim_name);


// Infer Shape // Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto infer_shape = std::make_shared<abstract::Shape>(x_shape); auto infer_shape = std::make_shared<abstract::Shape>(x_shape);


// Infer Type // Infer Type


+ 1
- 2
mindspore/core/ops/assert.cc View File

@@ -47,8 +47,7 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
} }
condition = TypeIdToType(kNumberTypeBool); condition = TypeIdToType(kNumberTypeBool);
} else { } else {
auto condition_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto condition_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name); CheckAndConvertUtils::CheckInteger("condition's rank", condition_shape[0], kLessEqual, 1, op_name);
if (condition_shape[0] == 1) { if (condition_shape[0] == 1) {
auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c()); auto condition_value = reinterpret_cast<bool *>(input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c());


+ 1
- 3
mindspore/core/ops/assign_add.cc View File

@@ -25,9 +25,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto value_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("value_shape", input_args[1]->BuildShape(), prim_name);
auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(value_shape); return std::make_shared<abstract::Shape>(value_shape);
} }




+ 1
- 1
mindspore/core/ops/atan.cc View File

@@ -27,7 +27,7 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name);


// Infer Shape // Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto infer_shape = std::make_shared<abstract::Shape>(x_shape); auto infer_shape = std::make_shared<abstract::Shape>(x_shape);


// Infer Type // Infer Type


+ 1
- 3
mindspore/core/ops/audio_spectrogram.cc View File

@@ -30,9 +30,7 @@ namespace {
abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive, abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (input_shape.size() != 2) { if (input_shape.size() != 2) {
MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
} }


+ 1
- 1
mindspore/core/ops/avg_pool.cc View File

@@ -82,7 +82,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};


+ 6
- 7
mindspore/core/ops/batch_norm.cc View File

@@ -75,20 +75,19 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name); CheckAndConvertUtils::CheckInteger("batch_norm_infer", input_args.size(), kEqual, 5, prim_name);


auto input_x = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
input_x = {input_x[0], input_x[3], input_x[1], input_x[2]}; input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
} }
auto scale = CheckAndConvertUtils::ConvertShapePtrToShape("scale", input_args[1]->BuildShape(), prim_name);
auto bias = CheckAndConvertUtils::ConvertShapePtrToShape("bias", input_args[2]->BuildShape(), prim_name);
auto mean = CheckAndConvertUtils::ConvertShapePtrToShape("mean", input_args[3]->BuildShape(), prim_name);
auto variance = CheckAndConvertUtils::ConvertShapePtrToShape("variance", input_args[4]->BuildShape(), prim_name);
auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->BuildShape())[kShape];


std::vector<int64_t> input_shape_norm; std::vector<int64_t> input_shape_norm;
if (format == NCHW) { if (format == NCHW) {
input_shape_norm =
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
input_shape_norm = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
} else { } else {
input_shape_norm.push_back(input_x[0]); input_shape_norm.push_back(input_x[0]);
input_shape_norm.push_back(input_x[3]); input_shape_norm.push_back(input_x[3]);


+ 4
- 6
mindspore/core/ops/batch_norm_fold.cc View File

@@ -68,12 +68,10 @@ AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const Pr
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShape("mean_shape", input_args[1]->BuildShape(), op_name);
auto variance_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("variance_shape", input_args[2]->BuildShape(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
auto global_step_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("global_step_shape", input_args[3]->BuildShape(), op_name);
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto variance_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto global_step_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name); CheckAndConvertUtils::Check("mean_shape", mean_shape, kEqual, "gamma_shape", variance_shape, op_name);
CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name); CheckAndConvertUtils::Check("mean_shape[0]", mean_shape[0], kEqual, "input channel", x_shape[1], op_name);
CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name); CheckAndConvertUtils::CheckInteger("global step shape len", global_step_shape.size(), kEqual, 1, op_name);


+ 1
- 1
mindspore/core/ops/batch_to_space.cc View File

@@ -55,7 +55,7 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types,
prim_name); prim_name);


auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); auto block_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops)); auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));


+ 1
- 1
mindspore/core/ops/batch_to_space_nd.cc View File

@@ -29,7 +29,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
auto out_shape = x_shape; auto out_shape = x_shape;
int64_t block_shape_prod = 1; int64_t block_shape_prod = 1;


+ 2
- 2
mindspore/core/ops/bias_add.cc View File

@@ -30,8 +30,8 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v
auto prim_name = primitive->name(); auto prim_name = primitive->name();
// check // check
CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name); CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name); CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name);
CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name);
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));


+ 3
- 4
mindspore/core/ops/binary_cross_entropy.cc View File

@@ -34,10 +34,9 @@ abstract::ShapePtr BinaryCrossEntroyInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); CheckAndConvertUtils::CheckInRange("binary_cross_entropy_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto weight_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
std::vector<int64_t> infer_shape; std::vector<int64_t> infer_shape;
if (weight_shape.size() < 1) { if (weight_shape.size() < 1) {


+ 1
- 1
mindspore/core/ops/broadcast.cc View File

@@ -50,7 +50,7 @@ AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const Primit
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
// infer shape // infer shape
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
// infer type // infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::vector<TypePtr> output_types; std::vector<TypePtr> output_types;


+ 1
- 1
mindspore/core/ops/broadcast_to.cc View File

@@ -24,7 +24,7 @@ abstract::ShapePtr BroadcastToInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto value_ptr = primitive->GetAttr(kShape); auto value_ptr = primitive->GetAttr(kShape);
auto input_x = GetValue<std::vector<int64_t>>(value_ptr); auto input_x = GetValue<std::vector<int64_t>>(value_ptr);
int64_t outer_dim_offset = input_x.size() - x_shape.size(); int64_t outer_dim_offset = input_x.size() - x_shape.size();


+ 1
- 1
mindspore/core/ops/ceil.cc View File

@@ -31,7 +31,7 @@ AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto infer_type = input_args[0]->BuildType(); auto infer_type = input_args[0]->BuildType();
auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name()); auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());


+ 2
- 4
mindspore/core/ops/concat.cc View File

@@ -43,8 +43,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("concat element num", elements.size(), kGreaterEqual, 1, prim_name);
auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>(); auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(element0); MS_EXCEPTION_IF_NULL(element0);
auto element0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("element0 shape", element0->BuildShape(), prim_name);
auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
auto element0_rank = SizeToLong(element0_shape.size()); auto element0_rank = SizeToLong(element0_shape.size());
auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); auto axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank}, CheckAndConvertUtils::CheckInRange<int64_t>("Concat axis", axis, kIncludeBoth, {-element0_rank - 1, element0_rank},
@@ -56,8 +55,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
int64_t all_shp = element0_shape[axis]; int64_t all_shp = element0_shape[axis];
for (size_t i = 1; i < elements.size(); ++i) { for (size_t i = 1; i < elements.size(); ++i) {
std::string elementi = "element" + std::to_string(i); std::string elementi = "element" + std::to_string(i);
auto elementi_shape =
CheckAndConvertUtils::ConvertShapePtrToShape(elementi + " shape", elements[i]->BuildShape(), prim_name);
auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(), CheckAndConvertUtils::CheckInteger(elementi + " shape rank", elementi_shape.size(), kEqual, element0_shape.size(),
prim_name); prim_name);
for (int64_t j = 0; j < element0_rank; ++j) { for (int64_t j = 0; j < element0_rank; ++j) {


+ 1
- 2
mindspore/core/ops/constant_of_shape.cc View File

@@ -24,8 +24,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape"); CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kEqual, 1, "ConstantOfShape");
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "ConstantOfShape");
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(input_shape); return std::make_shared<abstract::Shape>(input_shape);
} }




+ 2
- 2
mindspore/core/ops/conv2d.cc View File

@@ -79,8 +79,8 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};


+ 1
- 2
mindspore/core/ops/conv2d_transpose.cc View File

@@ -28,8 +28,7 @@ namespace {
abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive, abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[3]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(input_shape); return std::make_shared<abstract::Shape>(input_shape);
} }




+ 1
- 2
mindspore/core/ops/cos.cc View File

@@ -24,11 +24,10 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }




+ 1
- 1
mindspore/core/ops/crop.cc View File

@@ -49,7 +49,7 @@ AbstractBasePtr CropInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
// infer shape // infer shape
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), prim_name);
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
// infer type // infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
return std::make_shared<abstract::AbstractTensor>(x_type, out_shape); return std::make_shared<abstract::AbstractTensor>(x_type, out_shape);


+ 1
- 5
mindspore/core/ops/custom_extract_features.cc View File

@@ -24,18 +24,14 @@ namespace ops {
AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]); MS_EXCEPTION_IF_NULL(input_args[0]);
// auto input = input_args[0];

// Infer type // Infer type
auto output0_type = kInt32; auto output0_type = kInt32;
auto output1_type = kFloat32; auto output1_type = kFloat32;


// Infer shape // Infer shape
std::vector<int64_t> out_shape; std::vector<int64_t> out_shape;
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto string_num = input_shape[0]; auto string_num = input_shape[0];
if (string_num == 0) { if (string_num == 0) {
out_shape.push_back(1); out_shape.push_back(1);


+ 1
- 1
mindspore/core/ops/depth_to_space.cc View File

@@ -54,7 +54,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
auto input_x = input_args[0]->cast<abstract::AbstractTensorPtr>(); auto input_x = input_args[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(input_x); MS_EXCEPTION_IF_NULL(input_x);


auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};


+ 2
- 2
mindspore/core/ops/depthwise_conv2d.cc View File

@@ -119,8 +119,8 @@ abstract::ShapePtr DepthWiseConv2DInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); CheckAndConvertUtils::CheckInRange<size_t>("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};


+ 3
- 3
mindspore/core/ops/detection_post_process.cc View File

@@ -120,9 +120,9 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
auto boxes = input_args[0]; auto boxes = input_args[0];
auto scores = input_args[1]; auto scores = input_args[1];
auto anchors = input_args[2]; auto anchors = input_args[2];
auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShape("boxes_shape", boxes->BuildShape(), prim_name);
auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShape("scores_shape", scores->BuildShape(), prim_name);
auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShape("anchors_shape", anchors->BuildShape(), prim_name);
auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape];
auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape];
auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]}; boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]};


+ 1
- 1
mindspore/core/ops/dropout.cc View File

@@ -43,7 +43,7 @@ AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const Primitiv
CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("dropout_infer", input_args.size(), kEqual, 1, prim_name);


// Infer shape // Infer shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterEqual, 1, prim_name);
std::vector<int64_t> out_shape; std::vector<int64_t> out_shape;
out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end()); out_shape.insert(out_shape.end(), x_shape.begin(), x_shape.end());


+ 1
- 2
mindspore/core/ops/elu.cc View File

@@ -31,11 +31,10 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }




+ 1
- 1
mindspore/core/ops/expand_dims.cc View File

@@ -36,7 +36,7 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
// Infer shape // Infer shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto dim_val = GetValue<int64_t>(input_args[1]->BuildValue()); auto dim_val = GetValue<int64_t>(input_args[1]->BuildValue());
auto rank = x_shape.size(); auto rank = x_shape.size();
CheckAndConvertUtils::CheckInRange<int64_t>("axis", dim_val, kIncludeBoth, {-rank - 1, rank}, prim_name); CheckAndConvertUtils::CheckInRange<int64_t>("axis", dim_val, kIncludeBoth, {-rank - 1, rank}, prim_name);


+ 3
- 3
mindspore/core/ops/fake_quant_with_min_max_vars.cc View File

@@ -29,9 +29,9 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), prim_name);
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kGreaterEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kGreaterEqual, 1, prim_name);
CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name); CheckAndConvertUtils::Check("min_shape", min_shape, kEqual, "max_shape", max_shape, prim_name);
CheckAndConvertUtils::CheckInteger("min_shape", min_shape.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("min_shape", min_shape.size(), kEqual, 1, prim_name);


+ 3
- 3
mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc View File

@@ -44,9 +44,9 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShape("min_shape", input_args[1]->BuildShape(), op_name);
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShape("max_shape", input_args[2]->BuildShape(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name); CheckAndConvertUtils::CheckInteger("x rank", (int64_t)x_shape.size(), kGreaterThan, 1, op_name);
CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name); CheckAndConvertUtils::Check("min shape", min_shape, kEqual, "max shape", max_shape, op_name);
CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name); CheckAndConvertUtils::CheckInteger("min shape", (int64_t)min_shape.size(), kEqual, 1, op_name);


+ 1
- 2
mindspore/core/ops/fft_imag.cc View File

@@ -24,8 +24,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
in_shape.pop_back(); in_shape.pop_back();
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }


+ 1
- 1
mindspore/core/ops/fft_real.cc View File

@@ -33,7 +33,7 @@ AbstractBasePtr FftRealInfer(const abstract::AnalysisEnginePtr &, const Primitiv
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
auto out_dtype = kFloat32; auto out_dtype = kFloat32;
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
out_shape.pop_back(); out_shape.pop_back();
return std::make_shared<abstract::AbstractTensor>(out_dtype, std::make_shared<abstract::Shape>(out_shape)); return std::make_shared<abstract::AbstractTensor>(out_dtype, std::make_shared<abstract::Shape>(out_shape));
} }


+ 1
- 1
mindspore/core/ops/flatten.cc View File

@@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("input args size", input_args.size(), kGreaterEqual, 1, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto prod = 1; auto prod = 1;
int64_t size = x_shape.size(); int64_t size = x_shape.size();
for (int64_t i = 1; i < size; i++) { for (int64_t i = 1; i < size; i++) {


+ 1
- 2
mindspore/core/ops/floor.cc View File

@@ -28,11 +28,10 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }




+ 1
- 1
mindspore/core/ops/fusion/avg_pool_fusion.cc View File

@@ -53,7 +53,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};


+ 3
- 4
mindspore/core/ops/fusion/full_connection.cc View File

@@ -53,8 +53,8 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
MS_EXCEPTION_IF_NULL(input_args[1]); MS_EXCEPTION_IF_NULL(input_args[1]);
auto input0 = input_args[0]; auto input0 = input_args[0];
auto input1 = input_args[1]; auto input1 = input_args[1];
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input0->BuildShape(), prim_name);
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input1->BuildShape(), prim_name);
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input0->BuildShape())[kShape];
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input1->BuildShape())[kShape];
auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); auto prim_axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias)); auto has_bias = GetValue<bool>(primitive->GetAttr(kHasBias));
if (has_bias) { if (has_bias) {
@@ -78,8 +78,7 @@ AbstractBasePtr FullConnectionInfer(const abstract::AnalysisEnginePtr &, const P
new_k = input1_shape[1]; new_k = input1_shape[1];
} }
if (has_bias) { if (has_bias) {
auto input2_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), prim_name);
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
if (input2_shape[0] != input1_shape[0]) { if (input2_shape[0] != input1_shape[0]) {
MS_EXCEPTION(ValueError) << "Bias size invalid"; MS_EXCEPTION(ValueError) << "Bias size invalid";
} }


+ 1
- 1
mindspore/core/ops/fusion/max_pool_fusion.cc View File

@@ -53,7 +53,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};


+ 1
- 1
mindspore/core/ops/fusion/slice_fusion.cc View File

@@ -33,7 +33,7 @@ AbstractBasePtr SliceFusionInfer(const abstract::AnalysisEnginePtr &, const Prim
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_shape_len = (int64_t)x_shape.size(); auto x_shape_len = (int64_t)x_shape.size();
auto begin_v = input_args[1]->BuildValue(); auto begin_v = input_args[1]->BuildValue();
auto size_v = input_args[2]->BuildValue(); auto size_v = input_args[2]->BuildValue();


+ 2
- 2
mindspore/core/ops/gather_d.cc View File

@@ -29,8 +29,8 @@ abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::v
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
// check // check
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
int64_t x_rank = x_shape.size(); int64_t x_rank = x_shape.size();
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name); CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name);
auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue()); auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue());


+ 2
- 3
mindspore/core/ops/gather_nd.cc View File

@@ -32,9 +32,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto indices_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("indices_shape", input_args[1]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto input_rank = input_shape.size(); auto input_rank = input_shape.size();
auto indices_rank = indices_shape.size(); auto indices_rank = indices_shape.size();
CheckAndConvertUtils::CheckInteger("Input of indices data", input_rank, kGreaterEqual, CheckAndConvertUtils::CheckInteger("Input of indices data", input_rank, kGreaterEqual,


+ 1
- 2
mindspore/core/ops/gelu.cc View File

@@ -28,8 +28,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr GeLUInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_x", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(input_shape); return std::make_shared<abstract::Shape>(input_shape);
} }




+ 2
- 4
mindspore/core/ops/grad/batch_norm_grad.cc View File

@@ -47,13 +47,11 @@ bool BatchNormGrad::get_is_training() const {
AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr BatchNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[1]); MS_EXCEPTION_IF_NULL(input_args[1]);
MS_EXCEPTION_IF_NULL(input_args[2]); MS_EXCEPTION_IF_NULL(input_args[2]);
MS_EXCEPTION_IF_NULL(input_args[3]); MS_EXCEPTION_IF_NULL(input_args[3]);
auto y_backprop_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("y_backprop_shape", input_args[0]->BuildShape(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->BuildShape(), op_name);
auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape); CheckAndConvertUtils::Check("BatchNorm y_backprop_shape", y_backprop_shape, kEqual, "BatchNorm x_shape", x_shape);


auto dx = input_args[1]->Broaden(); auto dx = input_args[1]->Broaden();


+ 1
- 1
mindspore/core/ops/grad/bias_add_grad.cc View File

@@ -46,7 +46,7 @@ AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const Prim
MS_EXCEPTION_IF_NULL(input_args[0]); MS_EXCEPTION_IF_NULL(input_args[0]);


// Infer shape // Infer shape
auto inshape = CheckAndConvertUtils::ConvertShapePtrToShape("inshape", input_args[0]->BuildShape(), prim_name);
auto inshape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
for (size_t i = 0; i < inshape.size() - 1; i++) { for (size_t i = 0; i < inshape.size() - 1; i++) {
inshape[i] = 1; inshape[i] = 1;
} }


+ 3
- 4
mindspore/core/ops/grad/binary_cross_entropy_grad.cc View File

@@ -27,10 +27,9 @@ abstract::ShapePtr BinaryCrossEntroyGradInferShape(const PrimitivePtr &primitive
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto weight_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("weight_shape", input_args[2]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name); CheckAndConvertUtils::Check("x shape", x_shape, kEqual, "y shape", y_shape, prim_name);
if (weight_shape.size() < 1) { if (weight_shape.size() < 1) {
CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name); CheckAndConvertUtils::Check("y shape", y_shape, kEqual, "weight shape", weight_shape, prim_name);


+ 1
- 2
mindspore/core/ops/grad/dropout_grad.cc View File

@@ -35,8 +35,7 @@ namespace {
abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive, abstract::ShapePtr DropoutGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }




+ 1
- 3
mindspore/core/ops/grad/max_pool_grad.cc View File

@@ -21,10 +21,8 @@ namespace mindspore {
namespace ops { namespace ops {
AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr MaxPoolGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue()); MS_EXCEPTION_IF_NULL(input_args[0]->BuildValue());
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x1_shape", input_args[0]->BuildShape(), op_name);
auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type); MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element(); auto element = tensor_type->element();


+ 3
- 3
mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc View File

@@ -35,9 +35,9 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
prim_name); prim_name);


// Infer Shape // Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dout_shape", input_args[2]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto dout_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError); CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);




+ 3
- 3
mindspore/core/ops/grad/smooth_l1_loss_grad.cc View File

@@ -40,9 +40,9 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name); CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", input_args.size(), kEqual, 3, prim_name);


// Infer shape // Infer shape
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name);
auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[1]->BuildShape(), prim_name);
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShape("dloss", input_args[2]->BuildShape(), prim_name);
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError); CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError); CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError);




+ 1
- 2
mindspore/core/ops/hashtable_lookup.cc View File

@@ -27,9 +27,8 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const
for (auto input : input_args) { for (auto input : input_args) {
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
} }
auto op_name = primitive->name();
std::vector<int64_t> hits_shape; std::vector<int64_t> hits_shape;
auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
hits_shape.push_back(input[0]); hits_shape.push_back(input[0]);


auto value_type = input_args[2]->BuildType(); auto value_type = input_args[2]->BuildType();


+ 1
- 1
mindspore/core/ops/l2_normalize.cc View File

@@ -46,7 +46,7 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim
} }
const std::set<TypePtr> valid_types = {kFloat16, kFloat32}; const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name); (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size()); auto x_rank = SizeToLong(x_shape.size());
auto axiss = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis)); auto axiss = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
for (auto &axis : axiss) { for (auto &axis : axiss) {


+ 1
- 1
mindspore/core/ops/log.cc View File

@@ -24,7 +24,7 @@ namespace mindspore {
namespace ops { namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Log");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape); return std::make_shared<abstract::Shape>(x_shape);
} }




+ 1
- 2
mindspore/core/ops/logical_not.cc View File

@@ -24,8 +24,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr LogicalNotInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }




+ 1
- 1
mindspore/core/ops/lrn.cc View File

@@ -78,7 +78,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 4, prim_name);
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }


+ 3
- 3
mindspore/core/ops/lsh_projection.cc View File

@@ -32,14 +32,14 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name);
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name);
auto input0 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name); CheckAndConvertUtils::CheckInteger("input0_shape", input0.size(), kEqual, 2, op_name);
CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name); CheckAndConvertUtils::CheckInteger("input0_shape_dimen_1", input0[1], kLessEqual, 32, op_name);
CheckAndConvertUtils::CheckInteger("input1_shape", input1.size(), kGreaterEqual, 1, op_name); CheckAndConvertUtils::CheckInteger("input1_shape", input1.size(), kGreaterEqual, 1, op_name);


if (input_args.size() == 3) { if (input_args.size() == 3) {
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShape("input2_shape", input_args[2]->BuildShape(), op_name);
auto input2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input2_shape", input2.size(), kEqual, 1, op_name); CheckAndConvertUtils::CheckInteger("input2_shape", input2.size(), kEqual, 1, op_name);
CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name); CheckAndConvertUtils::CheckInteger("input2_shape_dimen_0", input2[0], kEqual, input1[0], op_name);
} }


+ 3
- 3
mindspore/core/ops/lstm.cc View File

@@ -32,9 +32,9 @@ AbstractBasePtr LstmInfer(const PrimitivePtr &primitive, const std::vector<Abstr
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("lstm_prim_infer", input_args.size(), kEqual, 4, prim_name);
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("h_shape", input_args[1]->BuildShape(), prim_name);
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("c_shape", input_args[2]->BuildShape(), prim_name);
auto x_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto h_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto c_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];


int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size)); int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size));
CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name); CheckAndConvertUtils::CheckInteger("x_shape.size()", x_input_shape.size(), kEqual, 3, prim_name);


+ 2
- 2
mindspore/core/ops/mat_mul.cc View File

@@ -26,8 +26,8 @@ abstract::ShapePtr MatMulInferShape(const PrimitivePtr &primitive, const std::ve
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name); CheckAndConvertUtils::CheckInteger("matmul_infer_input", input_args.size(), kEqual, 2, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto trans_a = GetValue<bool>(primitive->GetAttr(kTransposeA)); auto trans_a = GetValue<bool>(primitive->GetAttr(kTransposeA));
auto trans_b = GetValue<bool>(primitive->GetAttr(kTransposeB)); auto trans_b = GetValue<bool>(primitive->GetAttr(kTransposeB));




+ 2
- 3
mindspore/core/ops/matrix_diag.cc View File

@@ -30,9 +30,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto assist_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("assist_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto assist_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];


CheckAndConvertUtils::CheckInteger("assist rank", (int64_t)assist_shape.size(), kGreaterEqual, 2, prim_name); CheckAndConvertUtils::CheckInteger("assist rank", (int64_t)assist_shape.size(), kGreaterEqual, 2, prim_name);
CheckAndConvertUtils::Check("x_shape rank", (int64_t)x_shape.size() + 1, kLessEqual, "assist rank", CheckAndConvertUtils::Check("x_shape rank", (int64_t)x_shape.size() + 1, kLessEqual, "assist rank",


+ 1
- 1
mindspore/core/ops/max_pool.cc View File

@@ -82,7 +82,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat))); auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) { if (format == NHWC) {
in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};


+ 2
- 4
mindspore/core/ops/mfcc.cc View File

@@ -25,10 +25,8 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto first_input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name);
auto second_input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("second_input_shape", input_args[1]->BuildShape(), prim_name);
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name); CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name);
CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name);
std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1], std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1],


+ 1
- 1
mindspore/core/ops/one_hot.cc View File

@@ -31,7 +31,7 @@ abstract::ShapePtr OneHotInferShape(const PrimitivePtr &primitive, const std::ve
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name(); auto op_name = primitive->name();
int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis)); int64_t axis = GetValue<int64_t>(primitive->GetAttr(kAxis));
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name); CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-1, SizeToLong(in_shape.size())}, op_name);
auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue()); auto depth_val = GetValue<int64_t>(input_args[1]->BuildValue());
CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name); CheckAndConvertUtils::CheckInteger("depth", depth_val, kGreaterEqual, 0, op_name);


+ 1
- 3
mindspore/core/ops/ones_like.cc View File

@@ -28,9 +28,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(input_shape); return std::make_shared<abstract::Shape>(input_shape);
} }




+ 2
- 2
mindspore/core/ops/op_utils.cc View File

@@ -27,8 +27,8 @@ namespace mindspore {
namespace ops { namespace ops {
abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) {
MS_LOG(INFO) << "Do infer shape for op " << op_name; MS_LOG(INFO) << "Do infer shape for op " << op_name;
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->GetShapeTrack(), op_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
if (x_shape == y_shape) { if (x_shape == y_shape) {
return std::make_shared<abstract::Shape>(x_shape); return std::make_shared<abstract::Shape>(x_shape);
} }


+ 2
- 2
mindspore/core/ops/pack.cc View File

@@ -23,7 +23,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve
std::string name) { std::string name) {
CheckAndConvertUtils::CheckInteger("len of input_x", (int64_t)x_shapes.size(), kGreaterEqual, 1, name); CheckAndConvertUtils::CheckInteger("len of input_x", (int64_t)x_shapes.size(), kGreaterEqual, 1, name);
CheckAndConvertUtils::CheckSubClass("input_x[0]", x_types[0], {TypeIdToType(kObjectTypeTensorType)}, name); CheckAndConvertUtils::CheckSubClass("input_x[0]", x_types[0], {TypeIdToType(kObjectTypeTensorType)}, name);
auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape[0]", x_shapes[0], name);
auto output_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[0])[kShape];
int64_t rank_base = output_shape.size(); int64_t rank_base = output_shape.size();
int64_t N = x_shapes.size(); int64_t N = x_shapes.size();
// CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-rank_base-1, rank_base}, name); // CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeBoth, {-rank_base-1, rank_base}, name);
@@ -37,7 +37,7 @@ std::vector<int64_t> _get_pack_shape(std::vector<BaseShapePtr> x_shapes, std::ve
MS_EXCEPTION_IF_NULL(type0); MS_EXCEPTION_IF_NULL(type0);
CheckAndConvertUtils::Check("x_type[" + std::to_string(i) + "]", type->type_id(), kEqual, "base", type0->type_id(), CheckAndConvertUtils::Check("x_type[" + std::to_string(i) + "]", type->type_id(), kEqual, "base", type0->type_id(),
name); name);
auto shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape" + std::to_string(i), x_shapes[i], name);
auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shapes[i])[kShape];
if (shape != output_shape) { if (shape != output_shape) {
MS_EXCEPTION(ValueError) << "For '" + name + "' element " + std::to_string(i) + MS_EXCEPTION(ValueError) << "For '" + name + "' element " + std::to_string(i) +
"shape in input can't pack with first element."; "shape in input can't pack with first element.";


+ 1
- 1
mindspore/core/ops/pad.cc View File

@@ -25,7 +25,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings)); auto paddings_attr = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kPaddings));
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Pad");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()), CheckAndConvertUtils::CheckInteger("paddings_size", paddings_attr.size(), kEqual, int64_t(2 * x_shape.size()),
prim_name); prim_name);
int64_t size = paddings_attr.size(); int64_t size = paddings_attr.size();


+ 2
- 2
mindspore/core/ops/prelu.cc View File

@@ -25,8 +25,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto x = input_args[0]->BuildShape(); auto x = input_args[0]->BuildShape();
auto w = input_args[1]->BuildShape(); auto w = input_args[1]->BuildShape();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", x, prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", w, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x)[kShape];
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(w)[kShape];


CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kNotEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kNotEqual, 1, prim_name);
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 1, prim_name);


+ 1
- 2
mindspore/core/ops/prior_box.cc View File

@@ -112,7 +112,6 @@ void PriorBox::Init(const std::vector<int64_t> &min_sizes, const std::vector<int
AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]); MS_EXCEPTION_IF_NULL(input_args[0]);
std::vector<float> different_aspect_ratios{1.0f}; std::vector<float> different_aspect_ratios{1.0f};
auto aspect_ratios = GetValue<std::vector<float>>(primitive->GetAttr(kAspectRatios)); auto aspect_ratios = GetValue<std::vector<float>>(primitive->GetAttr(kAspectRatios));
@@ -129,7 +128,7 @@ AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const Primiti
} }
auto min_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kMinSizes)); auto min_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kMinSizes));
int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size(); int64_t num_priors_box = min_sizes.size() * different_aspect_ratios.size() + min_sizes.size();
auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto input = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
int64_t h = input[0] * input[1] * num_priors_box * 4; int64_t h = input[0] * input[1] * num_priors_box * 4;
std::vector<int64_t> output_shape{1, h, 1, 2}; std::vector<int64_t> output_shape{1, h, 1, 2};
return std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape); return std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);


+ 1
- 2
mindspore/core/ops/quant_dtype_cast.cc View File

@@ -32,13 +32,12 @@ void QuantDTypeCast::Init(const int64_t src_t, const int64_t dst_t) {
AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
MS_EXCEPTION_IF_NULL(input_args[0]); MS_EXCEPTION_IF_NULL(input_args[0]);
auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>(); auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_type); MS_EXCEPTION_IF_NULL(input_type);
auto dst_type = GetValue<int64_t>(primitive->GetAttr(kDstT)); auto dst_type = GetValue<int64_t>(primitive->GetAttr(kDstT));
MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type))); MS_ASSERT(input_type->element() == TypeIdToType(TypeId(dst_type)));
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(dst_type)), input_shape); return std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId(dst_type)), input_shape);
} }
REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast); REGISTER_PRIMITIVE_C(kNameQuantDTypeCast, QuantDTypeCast);


+ 1
- 2
mindspore/core/ops/reciprocal.cc View File

@@ -34,8 +34,7 @@ AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const Primi
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
// infer shape // infer shape
auto in_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
// infer type // infer type
std::set<TypePtr> valid_x_type = {kTensorType}; std::set<TypePtr> valid_x_type = {kTensorType};
auto x_type = CheckAndConvertUtils::CheckTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name); auto x_type = CheckAndConvertUtils::CheckTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);


+ 1
- 2
mindspore/core/ops/reduce.cc View File

@@ -71,8 +71,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A


MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto input_x_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_x_shape", input_args[0]->BuildShape(), prim_name);
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];


auto keep_dims = GetValue<bool>(primitive->GetAttr(kKeepDims)); auto keep_dims = GetValue<bool>(primitive->GetAttr(kKeepDims));
auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name); auto out_shape = infer_shape_reduce(input_x_shape, axis_value, keep_dims, prim_name);


+ 1
- 2
mindspore/core/ops/resize_bilinear.cc View File

@@ -49,8 +49,7 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P
CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("resize_bilinear_infer", input_args.size(), kEqual, 1, prim_name);


// Infer shape // Infer shape
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("input_shape_rank", input_shape.size(), kEqual, 4, prim_name);
std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]}; std::vector<int64_t> out_shape = {input_shape[0], input_shape[1]};
auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize)); auto size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize));


+ 2
- 4
mindspore/core/ops/reverse_sequence.cc View File

@@ -44,10 +44,8 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
// infer shape // infer shape
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto seq_lengths =
CheckAndConvertUtils::ConvertShapePtrToShape("seq_lengths", input_args[1]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto seq_lengths = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto seq_dim = GetValue<int64_t>(primitive->GetAttr(kSeqDim)); auto seq_dim = GetValue<int64_t>(primitive->GetAttr(kSeqDim));
auto batch_dim = GetValue<int64_t>(primitive->GetAttr(kBatchDim)); auto batch_dim = GetValue<int64_t>(primitive->GetAttr(kBatchDim));
CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name); CheckAndConvertUtils::CheckInteger("seq_dim", seq_dim, kLessEqual, input_shape.size(), prim_name);


+ 1
- 2
mindspore/core/ops/reverse_v2.cc View File

@@ -24,8 +24,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape); return std::make_shared<abstract::Shape>(x_shape);
} }




+ 1
- 3
mindspore/core/ops/rfft.cc View File

@@ -24,9 +24,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto first_input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("first_input_shape", input_args[0]->BuildShape(), prim_name);
auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto out_shape = first_input_shape; auto out_shape = first_input_shape;
out_shape[out_shape.size() - 1] = GetValue<int64_t>(primitive->GetAttr(kFftLength)) / 2 + 1; out_shape[out_shape.size() - 1] = GetValue<int64_t>(primitive->GetAttr(kFftLength)) / 2 + 1;
out_shape.push_back(2); out_shape.push_back(2);


+ 2
- 3
mindspore/core/ops/roi_pooling.cc View File

@@ -62,9 +62,8 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi
// Infer shape // Infer shape
auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH)); auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH));
auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW)); auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW));
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShape("roi_shape", input_args[1]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
std::vector<int64_t> output_shape; std::vector<int64_t> output_shape;
output_shape.push_back(roi_shape[0]); output_shape.push_back(roi_shape[0]);
output_shape.push_back(new_h); output_shape.push_back(new_h);


+ 1
- 1
mindspore/core/ops/round.cc View File

@@ -23,7 +23,7 @@ namespace mindspore {
namespace ops { namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "round");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape); return std::make_shared<abstract::Shape>(x_shape);
} }




+ 1
- 1
mindspore/core/ops/rsqrt.cc View File

@@ -30,7 +30,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->GetShapeTrack(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("input shape", in_shape.size(), kEqual, 1, prim_name);
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }


+ 1
- 1
mindspore/core/ops/scalar_summary.cc View File

@@ -29,7 +29,7 @@ abstract::ShapePtr ScalarSummaryInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
// check // check
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name);
return std::make_shared<abstract::Shape>(ShapeVector(1)); return std::make_shared<abstract::Shape>(ShapeVector(1));
} }


+ 2
- 4
mindspore/core/ops/scatter_nd.cc View File

@@ -29,10 +29,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
for (const auto &shape : shape_value_element) { for (const auto &shape : shape_value_element) {
CheckAndConvertUtils::CheckInteger("shape value", shape, kGreaterThan, 0, "ScatterNd"); CheckAndConvertUtils::CheckInteger("shape value", shape, kGreaterThan, 0, "ScatterNd");
} }
auto indices_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("indices_shape", input_args[0]->BuildShape(), "ScatterNd");
auto update_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("update_shape", input_args[1]->BuildShape(), "ScatterNd");
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto update_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual, update_shape[0], CheckAndConvertUtils::CheckInteger("indices_shape[0] and update_shape[0]", indices_shape[0], kEqual, update_shape[0],
"ScatterNd"); "ScatterNd");
return std::make_shared<abstract::Shape>(shape_value_element); return std::make_shared<abstract::Shape>(shape_value_element);


+ 2
- 2
mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc View File

@@ -34,8 +34,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
prim_name); prim_name);


// Infer shape // Infer shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError); CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);


// Infer type // Infer type


+ 1
- 1
mindspore/core/ops/sin.cc View File

@@ -31,7 +31,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Sin");
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(x_shape); return std::make_shared<abstract::Shape>(x_shape);
} }




+ 1
- 2
mindspore/core/ops/skip_gram.cc View File

@@ -23,7 +23,6 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
if (input_args.size() != 1) { if (input_args.size() != 1) {
MS_LOG(ERROR) << "Skip Gram should have one input"; MS_LOG(ERROR) << "Skip Gram should have one input";
} }
@@ -31,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
if (infer_value == nullptr) { if (infer_value == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime."; MS_LOG(INFO) << "Do infer shape in runtime.";
} }
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("in_shape", input_args[0]->BuildShape(), prim_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(in_shape); return std::make_shared<abstract::Shape>(in_shape);
} }




+ 2
- 2
mindspore/core/ops/smooth_l1_loss.cc View File

@@ -40,8 +40,8 @@ AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const Pri
CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name); CheckAndConvertUtils::CheckInteger("smooth_l1_loss_infer", input_args.size(), kEqual, 2, prim_name);


// Infer shape // Infer shape
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShape("prediction", input_args[0]->BuildShape(), prim_name);
auto target = CheckAndConvertUtils::ConvertShapePtrToShape("target", input_args[0]->BuildShape(), prim_name);
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError); CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);


// Infer type // Infer type


+ 2
- 4
mindspore/core/ops/softmax_cross_entropy_with_logits.cc View File

@@ -34,10 +34,8 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
prim_name); prim_name);


// Infer shape // Infer shape
auto logits_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("logits_shape", input_args[0]->BuildShape(), prim_name);
auto labels_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("labels_shape", input_args[1]->BuildShape(), prim_name);
auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto labels_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("logits shape", logits_shape, kEqual, "labels shape", labels_shape, prim_name, TypeError); CheckAndConvertUtils::Check("logits shape", logits_shape, kEqual, "labels shape", labels_shape, prim_name, TypeError);
std::vector<int64_t> loss_shape = {logits_shape[0]}; std::vector<int64_t> loss_shape = {logits_shape[0]};
auto dlogits_shape = logits_shape; auto dlogits_shape = logits_shape;


+ 1
- 2
mindspore/core/ops/space_to_batch.cc View File

@@ -29,8 +29,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("input shape", input_shape.size(), kEqual, 4, prim_name);
std::vector<int64_t> output_shape(input_shape.size()); std::vector<int64_t> output_shape(input_shape.size());
auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize)); auto block_shape_vector = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockSize));


+ 1
- 1
mindspore/core/ops/space_to_batch_nd.cc View File

@@ -29,7 +29,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
auto out_shape = x_shape; auto out_shape = x_shape;
int64_t block_shape_prod = 1; int64_t block_shape_prod = 1;


+ 1
- 2
mindspore/core/ops/sparse_softmax_cross_entropy_with_logits.cc View File

@@ -43,8 +43,7 @@ AbstractBasePtr SparseSoftmaxCrossEntropyWithLogitsInfer(const abstract::Analysi
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
// infer shape // infer shape
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
std::vector<int64_t> output_shape; std::vector<int64_t> output_shape;
if (GetValue<bool>(primitive->GetAttr(kIsGrad)) != 0) { if (GetValue<bool>(primitive->GetAttr(kIsGrad)) != 0) {
output_shape = input_shape; output_shape = input_shape;


+ 1
- 2
mindspore/core/ops/sparse_to_dense.cc View File

@@ -33,8 +33,7 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
// infer shape // infer shape
auto dense_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("dense_shape", input_args[3]->BuildShape(), prim_name);
auto dense_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
// infer type // infer type
auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element(); auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape); return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape);


+ 1
- 1
mindspore/core/ops/squeeze.cc View File

@@ -29,7 +29,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis)); auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAxis));
std::vector<int64_t> infer_shape; std::vector<int64_t> infer_shape;


auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto len = SizeToLong(in_shape.size()); auto len = SizeToLong(in_shape.size());
if (axis.empty()) { if (axis.empty()) {
std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape), std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape),


+ 2
- 5
mindspore/core/ops/stack.cc View File

@@ -21,7 +21,6 @@ namespace ops {
namespace { namespace {
abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();


if (input_args.size() != 1) { if (input_args.size() != 1) {
MS_LOG(ERROR) << "Invalid output size:" << input_args.size(); MS_LOG(ERROR) << "Invalid output size:" << input_args.size();
@@ -29,11 +28,9 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
if (input_args.size() < 1) { if (input_args.size() < 1) {
MS_LOG(ERROR) << "Invalid input size " << input_args.size(); MS_LOG(ERROR) << "Invalid input size " << input_args.size();
} }
auto input_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), prim_name);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) { for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) {
auto input_shape_tmp =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[i]->BuildShape(), prim_name);
auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape];
if (input_shape_tmp.size() != input_shape.size()) { if (input_shape_tmp.size() != input_shape.size()) {
MS_LOG(ERROR) << "All input shape size should be the same!"; MS_LOG(ERROR) << "All input shape size should be the same!";
} }


+ 1
- 1
mindspore/core/ops/strided_slice.cc View File

@@ -108,7 +108,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
auto temp_strides_v = input_args[3]->cast<abstract::AbstractTuplePtr>()->BuildValue(); auto temp_strides_v = input_args[3]->cast<abstract::AbstractTuplePtr>()->BuildValue();
auto strides_v = GetValue<std::vector<int64_t>>(temp_strides_v); auto strides_v = GetValue<std::vector<int64_t>>(temp_strides_v);


auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
int64_t x_rank = x_shape.size(); int64_t x_rank = x_shape.size();
int64_t slice_len = begin_v.size(); int64_t slice_len = begin_v.size();
std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask))); std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask)));


+ 1
- 1
mindspore/core/ops/tan.cc View File

@@ -33,7 +33,7 @@ AbstractBasePtr TanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
CheckAndConvertUtils::CheckInteger("tan_infer", input_args.size(), kEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("tan_infer", input_args.size(), kEqual, 1, prim_name);


// Infer Shape // Infer Shape
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto infer_shape = std::make_shared<abstract::Shape>(x_shape); auto infer_shape = std::make_shared<abstract::Shape>(x_shape);


// Infer Type // Infer Type


+ 2
- 5
mindspore/core/ops/tensor_list_from_tensor.cc View File

@@ -24,11 +24,8 @@ namespace {
abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive, abstract::ShapePtr TensorListFromTensorInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) { const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto input0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input0 shape", input_args[0]->BuildShape(), prim_name);
auto input1_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input1 shape", input_args[1]->BuildShape(), prim_name);
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
if (input0_shape.size() < 1) { if (input0_shape.size() < 1) {
MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!"; MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!";
} }


+ 2
- 5
mindspore/core/ops/tensor_list_stack.cc View File

@@ -52,9 +52,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const
for (const auto &input : input_args) { for (const auto &input : input_args) {
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
} }
auto op_name = primitive->name();
auto input0_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input0_shape", input_args[0]->BuildShape(), op_name);
auto input0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
int64_t num = std::accumulate(input0_shape.begin(), input0_shape.end(), 1LL, std::multiplies<int64_t>()); int64_t num = std::accumulate(input0_shape.begin(), input0_shape.end(), 1LL, std::multiplies<int64_t>());
if (num == 0) { if (num == 0) {
MS_LOG(ERROR) << "Try to stack a empty tensorlist!"; MS_LOG(ERROR) << "Try to stack a empty tensorlist!";
@@ -62,8 +60,7 @@ AbstractBasePtr TensorListStackInfer(const abstract::AnalysisEnginePtr &, const
if (input_args[1]->BuildShape() == nullptr) { if (input_args[1]->BuildShape() == nullptr) {
MS_LOG(ERROR) << "ele_shape->data_c() is nullptr"; MS_LOG(ERROR) << "ele_shape->data_c() is nullptr";
} }
auto input1_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input1_shape", input_args[1]->BuildShape(), op_name);
auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
input1_shape.insert(input1_shape.begin(), 1); input1_shape.insert(input1_shape.begin(), 1);
return std::make_shared<abstract::AbstractTensor>(input_args[0]->BuildType(), input1_shape); return std::make_shared<abstract::AbstractTensor>(input_args[0]->BuildType(), input1_shape);
} }


+ 1
- 1
mindspore/core/ops/tensor_summary.cc View File

@@ -29,7 +29,7 @@ abstract::ShapePtr TensorSummaryInferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name(); auto prim_name = primitive->name();
// check // check
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name); CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name);
return std::make_shared<abstract::Shape>(ShapeVector(1)); return std::make_shared<abstract::Shape>(ShapeVector(1));
} }


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save