Browse Source

change some check type api

pull/13901/head
LianLiguang 5 years ago
parent
commit
d9f4659cfd
100 changed files with 201 additions and 390 deletions
  1. +1
    -2
      mindspore/core/ops/abs.cc
  2. +4
    -8
      mindspore/core/ops/adam.cc
  3. +1
    -2
      mindspore/core/ops/add.cc
  4. +3
    -5
      mindspore/core/ops/addn.cc
  5. +1
    -1
      mindspore/core/ops/apply_momentum.cc
  6. +0
    -3
      mindspore/core/ops/arg_min.cc
  7. +2
    -8
      mindspore/core/ops/asin.cc
  8. +4
    -4
      mindspore/core/ops/assert.cc
  9. +1
    -2
      mindspore/core/ops/assign_add.cc
  10. +3
    -9
      mindspore/core/ops/atan.cc
  11. +5
    -5
      mindspore/core/ops/batch_norm.cc
  12. +2
    -17
      mindspore/core/ops/batch_norm_fold.cc
  13. +2
    -1
      mindspore/core/ops/batch_to_space.cc
  14. +1
    -2
      mindspore/core/ops/bias_add.cc
  15. +2
    -2
      mindspore/core/ops/binary_cross_entropy.cc
  16. +1
    -1
      mindspore/core/ops/broadcast.cc
  17. +3
    -4
      mindspore/core/ops/broadcast_to.cc
  18. +2
    -6
      mindspore/core/ops/ceil.cc
  19. +1
    -3
      mindspore/core/ops/concat.cc
  20. +1
    -2
      mindspore/core/ops/constant.cc
  21. +2
    -7
      mindspore/core/ops/conv2d.cc
  22. +2
    -3
      mindspore/core/ops/conv2d_transpose.cc
  23. +1
    -2
      mindspore/core/ops/cos.cc
  24. +2
    -2
      mindspore/core/ops/custom_extract_features.cc
  25. +3
    -3
      mindspore/core/ops/custom_predict.cc
  26. +3
    -3
      mindspore/core/ops/depthwise_conv2d.cc
  27. +1
    -1
      mindspore/core/ops/detection_post_process.cc
  28. +1
    -2
      mindspore/core/ops/div.cc
  29. +3
    -9
      mindspore/core/ops/dropout.cc
  30. +2
    -3
      mindspore/core/ops/elu.cc
  31. +3
    -8
      mindspore/core/ops/embedding_lookup.cc
  32. +1
    -2
      mindspore/core/ops/equal.cc
  33. +1
    -2
      mindspore/core/ops/exp.cc
  34. +3
    -3
      mindspore/core/ops/expand_dims.cc
  35. +2
    -3
      mindspore/core/ops/fake_quant_with_min_max_vars.cc
  36. +1
    -2
      mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc
  37. +1
    -1
      mindspore/core/ops/fft_imag.cc
  38. +2
    -2
      mindspore/core/ops/fill.cc
  39. +1
    -1
      mindspore/core/ops/flatten.cc
  40. +3
    -4
      mindspore/core/ops/floor.cc
  41. +1
    -2
      mindspore/core/ops/fusion/add_fusion.cc
  42. +1
    -2
      mindspore/core/ops/fusion/pow_fusion.cc
  43. +5
    -5
      mindspore/core/ops/gather.cc
  44. +3
    -4
      mindspore/core/ops/gather_nd.cc
  45. +2
    -3
      mindspore/core/ops/gelu.cc
  46. +2
    -2
      mindspore/core/ops/grad/binary_cross_entropy_grad.cc
  47. +2
    -3
      mindspore/core/ops/grad/conv2d_backprop_filter.cc
  48. +1
    -1
      mindspore/core/ops/grad/conv2d_backprop_input.cc
  49. +2
    -2
      mindspore/core/ops/grad/dropout_grad.cc
  50. +3
    -8
      mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc
  51. +3
    -7
      mindspore/core/ops/grad/smooth_l1_loss_grad.cc
  52. +1
    -1
      mindspore/core/ops/hashtable_lookup.cc
  53. +2
    -2
      mindspore/core/ops/l2_normalize.cc
  54. +1
    -2
      mindspore/core/ops/leaky_relu.cc
  55. +1
    -1
      mindspore/core/ops/less.cc
  56. +1
    -2
      mindspore/core/ops/less_equal.cc
  57. +1
    -2
      mindspore/core/ops/local_response_normalization.cc
  58. +3
    -4
      mindspore/core/ops/log.cc
  59. +2
    -6
      mindspore/core/ops/logical_and.cc
  60. +2
    -6
      mindspore/core/ops/logical_not.cc
  61. +2
    -6
      mindspore/core/ops/logical_or.cc
  62. +2
    -3
      mindspore/core/ops/lrn.cc
  63. +1
    -2
      mindspore/core/ops/lsh_projection.cc
  64. +2
    -7
      mindspore/core/ops/mat_mul.cc
  65. +2
    -10
      mindspore/core/ops/matrix_diag.cc
  66. +1
    -2
      mindspore/core/ops/maximum.cc
  67. +5
    -8
      mindspore/core/ops/merge.cc
  68. +1
    -2
      mindspore/core/ops/minimum.cc
  69. +2
    -1
      mindspore/core/ops/neg.cc
  70. +1
    -1
      mindspore/core/ops/non_max_suppression.cc
  71. +5
    -11
      mindspore/core/ops/one_hot.cc
  72. +1
    -6
      mindspore/core/ops/ones_like.cc
  73. +5
    -7
      mindspore/core/ops/op_utils.h
  74. +2
    -4
      mindspore/core/ops/pad.cc
  75. +1
    -2
      mindspore/core/ops/pow.cc
  76. +4
    -7
      mindspore/core/ops/prelu.cc
  77. +1
    -1
      mindspore/core/ops/prior_box.cc
  78. +1
    -2
      mindspore/core/ops/range.cc
  79. +2
    -2
      mindspore/core/ops/rank.cc
  80. +1
    -2
      mindspore/core/ops/real_div.cc
  81. +2
    -3
      mindspore/core/ops/reciprocal.cc
  82. +2
    -4
      mindspore/core/ops/reduce.cc
  83. +3
    -4
      mindspore/core/ops/relu6.cc
  84. +3
    -6
      mindspore/core/ops/resize_bilinear.cc
  85. +6
    -7
      mindspore/core/ops/reverse_sequence.cc
  86. +3
    -15
      mindspore/core/ops/reverse_v2.cc
  87. +1
    -1
      mindspore/core/ops/rfft.cc
  88. +1
    -7
      mindspore/core/ops/round.cc
  89. +2
    -5
      mindspore/core/ops/rsqrt.cc
  90. +3
    -3
      mindspore/core/ops/scatter_nd.cc
  91. +3
    -7
      mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc
  92. +2
    -3
      mindspore/core/ops/sin.cc
  93. +2
    -3
      mindspore/core/ops/smooth_l1_loss.cc
  94. +2
    -5
      mindspore/core/ops/softmax.cc
  95. +2
    -3
      mindspore/core/ops/softmax_cross_entropy_with_logits.cc
  96. +1
    -2
      mindspore/core/ops/space_to_batch.cc
  97. +1
    -2
      mindspore/core/ops/space_to_batch_nd.cc
  98. +0
    -4
      mindspore/core/ops/sparse_to_dense.cc
  99. +2
    -3
      mindspore/core/ops/squared_difference.cc
  100. +1
    -2
      mindspore/core/ops/sub.cc

+ 1
- 2
mindspore/core/ops/abs.cc View File

@@ -46,8 +46,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("input_x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 4
- 8
mindspore/core/ops/adam.cc View File

@@ -42,14 +42,10 @@ abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::ve
auto m_type = input_args[1]->BuildType();
auto v_type = input_args[2]->BuildType();
auto grad_type = input_args[9]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);

auto infer_var_type = var_type->cast<TensorTypePtr>()->element();
auto infer_m_type = m_type->cast<TensorTypePtr>()->element();
auto infer_v_type = v_type->cast<TensorTypePtr>()->element();
auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
// auto infer_grad_type = grad_type->cast<TensorTypePtr>()->element();
auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape);


+ 1
- 2
mindspore/core/ops/add.cc View File

@@ -40,8 +40,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 3
- 5
mindspore/core/ops/addn.cc View File

@@ -56,12 +56,10 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
}
types.emplace(elementi, elements[i]->BuildType());
}
std::set<TypeId> valid_types = common_valid_types;
valid_types.insert(kNumberTypeBool);
std::set<TypePtr> valid_types = common_valid_types;
valid_types.insert(kBool);
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);

return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
std::make_shared<abstract::Shape>(element0_shape));
return std::make_shared<abstract::AbstractTensor>(infer_type, std::make_shared<abstract::Shape>(element0_shape));
}
REGISTER_PRIMITIVE_C(kNameAddN, AddN);
} // namespace ops


+ 1
- 1
mindspore/core/ops/apply_momentum.cc View File

@@ -68,7 +68,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
auto l_type = input_args[2]->BuildType();
auto g_type = input_args[3]->BuildType();
auto m_type = input_args[4]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
std::map<std::string, TypePtr> args;


+ 0
- 3
mindspore/core/ops/arg_min.cc View File

@@ -62,9 +62,6 @@ AbstractBasePtr ArgMinInfer(const abstract::AnalysisEnginePtr &, const Primitive

// Infer type
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim_name);

return std::make_shared<abstract::AbstractTensor>(x_dtype, std::make_shared<abstract::Shape>(out_shape));
}
REGISTER_PRIMITIVE_C(kNameArgMin, ArgMin);


+ 2
- 8
mindspore/core/ops/asin.cc View File

@@ -36,14 +36,8 @@ AbstractBasePtr AsinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt

// Infer Type
auto dtype = input_args[0]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
auto tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));

const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
auto infer_type = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
}
REGISTER_PRIMITIVE_C(kNameAsin, Asin);


+ 4
- 4
mindspore/core/ops/assert.cc View File

@@ -61,15 +61,15 @@ AbstractBasePtr AssertInfer(const abstract::AnalysisEnginePtr &, const Primitive
condition = input_args[0]->BuildType();
}
std::vector<int64_t> output_shape = {1};
std::set<TypeId> local_bool = {kNumberTypeBool};
std::set<TypePtr> local_bool = {kBool};
std::map<std::string, TypePtr> args = {{"condition", condition}};
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, local_bool, op_name);
auto inputs_type = input_args[1]->BuildType()->cast<TuplePtr>()->elements();
for (auto dtype : inputs_type) {
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
std::set<TypePtr> template_types = {kTensorType};
CheckAndConvertUtils::CheckSubClass("input", dtype, template_types, op_name);
}
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), output_shape);
return std::make_shared<abstract::AbstractTensor>(kInt32, output_shape);
}
REGISTER_PRIMITIVE_C(kNameAssert, Assert);
} // namespace ops


+ 1
- 2
mindspore/core/ops/assign_add.cc View File

@@ -38,8 +38,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
types.emplace("x", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
// check_scalar_or_tensor_types_same
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, "AssignAdd");
}
} // namespace
AbstractBasePtr AssignAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 3
- 9
mindspore/core/ops/atan.cc View File

@@ -34,15 +34,9 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt

// Infer Type
auto dtype = input_args[0]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt32};
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
auto tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));

return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kInt32};
auto element = CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(element, infer_shape->shape());
}
REGISTER_PRIMITIVE_C(kNameAtan, Atan);
} // namespace ops


+ 5
- 5
mindspore/core/ops/batch_norm.cc View File

@@ -107,20 +107,20 @@ AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const Primit
}

// Infer type
auto input_x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto scale_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
auto bias_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();

const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto input_x_type =
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
std::map<std::string, TypePtr> args;
args.emplace("scale", input_args[1]->BuildType());
args.emplace("bias", input_args[2]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
std::map<std::string, TypePtr> args_moving;
args_moving.emplace("scale", input_args[2]->BuildType());
args_moving.emplace("bias", input_args[3]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);

auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);


+ 2
- 17
mindspore/core/ops/batch_norm_fold.cc View File

@@ -87,23 +87,8 @@ AbstractBasePtr BatchNormFoldInfer(const abstract::AnalysisEnginePtr &, const Pr
auto global_step_type = input_args[3]->BuildType();

std::map<std::string, TypePtr> args = {{"x", x_type}, {"mean", mean_type}, {"variance", variance_type}};
CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kNumberTypeInt32}, op_name);

auto tensor_type0 = x_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type0);
auto element0 = tensor_type0->element();

auto tensor_type1 = mean_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type1);
auto element1 = tensor_type1->element();

auto tensor_type2 = variance_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type2);
auto element2 = tensor_type2->element();

CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "mean_type", element1->type_id(), op_name);
CheckAndConvertUtils::Check("input type", element0->type_id(), kEqual, "variance_type", element2->type_id(), op_name);
auto element0 = CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("gloabal_step", global_step_type, {kInt32}, op_name);

auto output = std::make_shared<abstract::AbstractTensor>(element0, mean_shape);
AbstractBasePtrList output1 = {output, output, output, output};


+ 2
- 1
mindspore/core/ops/batch_to_space.cc View File

@@ -54,7 +54,8 @@ AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types,
prim_name);

auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);


+ 1
- 2
mindspore/core/ops/bias_add.cc View File

@@ -55,8 +55,7 @@ TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
std::map<std::string, TypePtr> types;
types.emplace("input_x", input_args[0]->BuildType());
types.emplace("bias", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace
void BiasAdd::set_format(const Format &format) {


+ 2
- 2
mindspore/core/ops/binary_cross_entropy.cc View File

@@ -57,7 +57,7 @@ TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<A
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x_shape", input_args[0]->BuildType());
types.emplace("y_shape", input_args[1]->BuildType());
@@ -67,7 +67,7 @@ TypePtr BinaryCrossEntroyInferType(const PrimitivePtr &prim, const std::vector<A
types.emplace("weight_shape", input_args[2]->BuildType());
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
return TypeIdToType(infer_type);
return infer_type;
}
} // namespace



+ 1
- 1
mindspore/core/ops/broadcast.cc View File

@@ -56,7 +56,7 @@ AbstractBasePtr BroadcastInfer(const abstract::AnalysisEnginePtr &, const Primit
// infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::vector<TypePtr> output_types;
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
for (size_t i = 0; i < input_args.size(); i++) {
auto out_type = input_args[i]->BuildType()->cast<TensorTypePtr>()->element();
output_types.push_back(out_type);


+ 3
- 4
mindspore/core/ops/broadcast_to.cc View File

@@ -57,11 +57,10 @@ TypePtr BroadcastToInferType(const PrimitivePtr &prim, const std::vector<Abstrac
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> template_types = {TypeIdToType(kObjectTypeTensorType)};
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>();
std::set<TypePtr> template_types = {kTensorType};
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
auto infer_dtype = input_args[0]->BuildType()->type_id();
return TypeIdToType(infer_dtype);
return x_dtype->element();
}
} // namespace



+ 2
- 6
mindspore/core/ops/ceil.cc View File

@@ -33,13 +33,9 @@ AbstractBasePtr CeilInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), "Ceil");
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto infer_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());
MS_EXCEPTION_IF_NULL(infer_type);
auto tensor_type = infer_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();
auto data_type = CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, primitive->name());
MS_EXCEPTION_IF_NULL(data_type);
return std::make_shared<abstract::AbstractTensor>(data_type, x_shape);
}


+ 1
- 3
mindspore/core/ops/concat.cc View File

@@ -74,9 +74,7 @@ AbstractBasePtr ConcatInfer(const abstract::AnalysisEnginePtr &, const Primitive
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, all_types, prim_name);
auto ret_shape = element0_shape;
ret_shape[axis] = all_shp;

return std::make_shared<abstract::AbstractTensor>(TypeIdToType(infer_type),
std::make_shared<abstract::Shape>(ret_shape));
return std::make_shared<abstract::AbstractTensor>(infer_type, std::make_shared<abstract::Shape>(ret_shape));
}
REGISTER_PRIMITIVE_C(kNameConcat, Concat);
} // namespace ops


+ 1
- 2
mindspore/core/ops/constant.cc View File

@@ -42,8 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 2
- 7
mindspore/core/ops/conv2d.cc View File

@@ -107,16 +107,11 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeFloat16,
kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (infer_type == kNumberTypeInt8) {
return TypeIdToType(kNumberTypeInt32);
}
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode,


+ 2
- 3
mindspore/core/ops/conv2d_transpose.cc View File

@@ -40,12 +40,11 @@ TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<Abs
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("doutput_dtye", input_args[0]->BuildType());
types.emplace("w_dtype", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace



+ 1
- 2
mindspore/core/ops/cos.cc View File

@@ -40,8 +40,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 2
- 2
mindspore/core/ops/custom_extract_features.cc View File

@@ -31,8 +31,8 @@ AbstractBasePtr CustomExtractFeaturesInfer(const abstract::AnalysisEnginePtr &,
// auto input = input_args[0];

// Infer type
auto output0_type = TypeIdToType(kNumberTypeInt32);
auto output1_type = TypeIdToType(kNumberTypeFloat32);
auto output0_type = kInt32;
auto output1_type = kFloat32;

// Infer shape
std::vector<int64_t> out_shape;


+ 3
- 3
mindspore/core/ops/custom_predict.cc View File

@@ -47,14 +47,14 @@ AbstractBasePtr CustomPredictInfer(const abstract::AnalysisEnginePtr &, const Pr
MS_EXCEPTION_IF_NULL(primitive);
auto CustomPredict_prim = primitive->cast<PrimCustomPredictPtr>();
MS_EXCEPTION_IF_NULL(CustomPredict_prim);
for (auto input : input_args) {
for (const auto &input : input_args) {
MS_EXCEPTION_IF_NULL(input);
}
std::vector<int64_t> shape;
shape.push_back(CustomPredict_prim->get_output_num());

auto output0 = std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeFloat32), shape);
auto output0 = std::make_shared<abstract::AbstractTensor>(kInt32, shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
AbstractBasePtrList output = {output0, output1};
return std::make_shared<abstract::AbstractTuple>(output);
}


+ 3
- 3
mindspore/core/ops/depthwise_conv2d.cc View File

@@ -216,10 +216,10 @@ TypePtr DepthWiseConv2DInferType(const PrimitivePtr &prim, const std::vector<Abs
types.emplace("x", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
if (infer_type == kNumberTypeInt8) {
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
if (*infer_type == *kInt8) {
return kInt32;
}
return TypeIdToType(infer_type);
return infer_type;
}

AbstractBasePtr DepthWiseConv2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 1
- 1
mindspore/core/ops/detection_post_process.cc View File

@@ -157,7 +157,7 @@ AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, c
std::vector<int64_t> output_num_shape = {1};

// Infer type
auto output_type = TypeIdToType(kNumberTypeFloat32);
auto output_type = kFloat32;

auto output0 = std::make_shared<abstract::AbstractTensor>(output_type, output_boxes_shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(output_type, output_class_shape);


+ 1
- 2
mindspore/core/ops/div.cc View File

@@ -41,8 +41,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 3
- 9
mindspore/core/ops/dropout.cc View File

@@ -53,15 +53,9 @@ AbstractBasePtr DropoutInfer(const abstract::AnalysisEnginePtr &, const Primitiv
auto infer_shape = std::make_shared<abstract::Shape>(out_shape);

// Infer type
auto dtype = input_args[0]->BuildType();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", dtype, valid_types, prim_name);
auto tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
auto infer_type = std::make_shared<TensorType>(TypeIdToType(element->type_id()));

const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto infer_type =
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", input_args[0]->BuildType(), valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape->shape());
}
REGISTER_PRIMITIVE_C(kNameDropout, Dropout);


+ 2
- 3
mindspore/core/ops/elu.cc View File

@@ -46,10 +46,9 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
void Elu::Init(const float alpha) { this->set_alpha(alpha); }


+ 3
- 8
mindspore/core/ops/embedding_lookup.cc View File

@@ -45,14 +45,9 @@ AbstractBasePtr EmbeddingLookupInfer(const abstract::AnalysisEnginePtr &, const
MS_EXCEPTION_IF_NULL(params);
auto indices = input_args[1]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(indices);
const std::set<TypeId> int_valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64};
CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices->BuildType(), int_valid_types, prim_name);
MS_EXCEPTION_IF_NULL(input_args[2]->BuildType());
auto offset_type = input_args[2]->BuildType()->type_id();
if (int_valid_types.find(offset_type) == int_valid_types.end()) {
MS_LOG(EXCEPTION) << "offset must be int.";
}

const std::set<TypePtr> int_valid_types = {kInt8, kInt16, kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices->BuildType(), int_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("offset", input_args[2]->BuildType(), int_valid_types, prim_name);
MS_EXCEPTION_IF_NULL(params->shape());
auto params_shp = params->shape()->shape();
MS_EXCEPTION_IF_NULL(indices->shape());


+ 1
- 2
mindspore/core/ops/equal.cc View File

@@ -42,8 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 1
- 2
mindspore/core/ops/exp.cc View File

@@ -39,8 +39,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace
AbstractBasePtr ExpInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 3
- 3
mindspore/core/ops/expand_dims.cc View File

@@ -50,10 +50,10 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
out_shape.insert(out_shape.begin() + dim_val, 1, 1);

// Infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)};
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
std::set<TypePtr> valid_x_type = {kTensorType};
CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name);
return std::make_shared<abstract::AbstractTensor>(x_type, out_shape);
return std::make_shared<abstract::AbstractTensor>(x_type->element(), out_shape);
}
REGISTER_PRIMITIVE_C(kNameExpandDims, ExpandDims);
} // namespace ops


+ 2
- 3
mindspore/core/ops/fake_quant_with_min_max_vars.cc View File

@@ -48,7 +48,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
@@ -56,8 +56,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
types.emplace("x", input_args[0]->BuildType());
types.emplace("min", input_args[1]->BuildType());
types.emplace("max", input_args[2]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
void FakeQuantWithMinMaxVars::Init(const bool narrow_range, const int64_t num_bits) {


+ 1
- 2
mindspore/core/ops/fake_quant_with_min_max_vars_per_channel.cc View File

@@ -60,8 +60,7 @@ AbstractBasePtr FakeQuantWithMinMaxVarsPerChannelInfer(const abstract::AnalysisE
std::vector<std::string> type_name = {"x", "min", "max"};
std::vector<TypePtr> type = {x_type, min_type, max_type};
for (int64_t i = 0; i < 3; i++) {
CheckAndConvertUtils::CheckTensorTypeValid(type_name[i], type[i], {kNumberTypeFloat16, kNumberTypeFloat32},
op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid(type_name[i], type[i], {kFloat16, kFloat32}, op_name);
}
auto tensor_type = x_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);


+ 1
- 1
mindspore/core/ops/fft_imag.cc View File

@@ -36,7 +36,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return TypeIdToType(kNumberTypeFloat32);
return kFloat32;
}
} // namespace



+ 2
- 2
mindspore/core/ops/fill.cc View File

@@ -37,8 +37,8 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
auto dtype = dtype_value->cast<TypePtr>();
MS_EXCEPTION_IF_NULL(dtype);
auto valid_types = common_valid_types;
valid_types.insert(kNumberTypeBool);
CheckAndConvertUtils::CheckTypeSame("output datatype", dtype, valid_types, prim_name);
valid_types.insert(kBool);
(void)CheckAndConvertUtils::CheckTypeValid("output datatype", dtype, valid_types, prim_name);
auto out_shape = GetValue<std::vector<int64_t>>(input_args[1]->BuildValue());
auto x_type = input_args[2]->BuildType();
auto x_type_id = x_type->type_id();


+ 1
- 1
mindspore/core/ops/flatten.cc View File

@@ -42,7 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
const std::set<TypePtr> valid_types = {TypeIdToType(kObjectTypeTensorType)};
const std::set<TypePtr> valid_types = {kTensorType};
CheckAndConvertUtils::CheckSubClass("infer type", input_args[0]->BuildType(), valid_types, prim->name());
return infer_type;
}


+ 3
- 4
mindspore/core/ops/floor.cc View File

@@ -39,14 +39,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
AbstractBasePtr FloorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 1
- 2
mindspore/core/ops/fusion/add_fusion.cc View File

@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 1
- 2
mindspore/core/ops/fusion/pow_fusion.cc View File

@@ -50,8 +50,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 5
- 5
mindspore/core/ops/gather.cc View File

@@ -27,12 +27,12 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name);

// Infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
const std::set<TypeId> valid_index_types = {kNumberTypeInt32, kNumberTypeInt64};
std::set<TypePtr> valid_x_type = {kTensorType};
auto x_type =
CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
std::set<TypePtr> valid_index_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name);
std::set<TypePtr> valid_dim_type = {TypeIdToType(kNumberTypeInt32), TypeIdToType(kNumberTypeInt64)};
std::set<TypePtr> valid_dim_type = {kInt32, kInt64};
CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name);

// Infer shape


+ 3
- 4
mindspore/core/ops/gather_nd.cc View File

@@ -52,14 +52,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64};
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64};
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
types.emplace("input_x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
AbstractBasePtr GatherNdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 2
- 3
mindspore/core/ops/gelu.cc View File

@@ -39,11 +39,10 @@ TypePtr GeLUInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("input_x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 2
- 2
mindspore/core/ops/grad/binary_cross_entropy_grad.cc View File

@@ -44,7 +44,7 @@ TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vect
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x_shape", input_args[0]->BuildType());
types.emplace("y_shape", input_args[1]->BuildType());
@@ -54,7 +54,7 @@ TypePtr BinaryCrossEntroyGradInferType(const PrimitivePtr &prim, const std::vect
types.emplace("weight_shape", input_args[2]->BuildType());
infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
return TypeIdToType(infer_type);
return infer_type;
}
} // namespace
void BinaryCrossEntropyGrad::Init(const Reduction &reduction) { set_reduction(reduction); }


+ 2
- 3
mindspore/core/ops/grad/conv2d_backprop_filter.cc View File

@@ -36,12 +36,11 @@ TypePtr Conv2DBackpropFilterInferType(const PrimitivePtr &prim, const std::vecto
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kInt8, kInt32, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("drotput", input_args[0]->BuildType());
types.emplace("input_x", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace



+ 1
- 1
mindspore/core/ops/grad/conv2d_backprop_input.cc View File

@@ -28,7 +28,7 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, co
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
for (auto item : input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto doutput = input_args[0];


+ 2
- 2
mindspore/core/ops/grad/dropout_grad.cc View File

@@ -49,8 +49,8 @@ TypePtr DropoutGradInferType(const PrimitivePtr &prim, const std::vector<Abstrac
auto op_name = DropoutGrad_prim->name();
auto mask_dtype = input_args[1]->BuildType();
auto dy_dtype = input_args[0]->BuildType();
CheckAndConvertUtils::CheckSubClass("mask", mask_dtype, {TypeIdToType(kObjectTypeTensorType)}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("dy", dy_dtype, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("mask", mask_dtype, {kTensorType}, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("dy", dy_dtype, {kFloat16, kFloat32}, op_name);
auto tensor_type = dy_dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();


+ 3
- 8
mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc View File

@@ -44,18 +44,13 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "dout_shape", dout_shape, prim_name, TypeError);

// Infer type
const std::set<TypeId> valid_types = {
kNumberTypeBool, kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16,
kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8,
kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat,
kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex64};
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
std::map<std::string, TypePtr> args;
args.emplace("x_type", input_args[0]->BuildType());
args.emplace("y_type", input_args[1]->BuildType());
args.emplace("dout_type", input_args[2]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
auto dout_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();

auto dout_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(dout_type, x_shape);
}
REGISTER_PRIMITIVE_C(kNameSigmoidCrossEntropyWithLogitsGrad, SigmoidCrossEntropyWithLogitsGrad);


+ 3
- 7
mindspore/core/ops/grad/smooth_l1_loss_grad.cc View File

@@ -49,17 +49,13 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError);

// Infer type
const std::set<TypeId> valid_types = {
kNumberTypeBool, kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16,
kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8,
kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat,
kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex64};
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
std::map<std::string, TypePtr> args;
args.emplace("prediction", input_args[0]->BuildType());
args.emplace("target", input_args[1]->BuildType());
args.emplace("dloss", input_args[2]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
auto dloss_type = input_args[2]->BuildType()->cast<TensorTypePtr>()->element();
auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);

return std::make_shared<abstract::AbstractTensor>(dloss_type, prediction);
}


+ 1
- 1
mindspore/core/ops/hashtable_lookup.cc View File

@@ -41,7 +41,7 @@ AbstractBasePtr HashtableLookupInfer(const abstract::AnalysisEnginePtr &, const
auto data_type = tensor_type->element();
std::vector<int64_t> value_shape;
auto output = std::make_shared<abstract::AbstractTensor>(data_type, value_shape);
auto hits = std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt8), hits_shape);
auto hits = std::make_shared<abstract::AbstractTensor>(kInt8, hits_shape);
AbstractBasePtrList output1 = {output, hits};

if (input_args[0]->BuildValue()->cast<tensor::TensorPtr>()->data_c() == nullptr) {


+ 2
- 2
mindspore/core/ops/l2_normalize.cc View File

@@ -49,8 +49,8 @@ AbstractBasePtr L2NormalizeInfer(const abstract::AnalysisEnginePtr &, const Prim
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto x_rank = SizeToLong(x_shape.size());
auto axiss = prim->get_axis();


+ 1
- 2
mindspore/core/ops/leaky_relu.cc View File

@@ -35,8 +35,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace
void LeakyRelu::Init(const float negative_slope) { this->set_negative_slope(negative_slope); }


+ 1
- 1
mindspore/core/ops/less.cc View File

@@ -41,7 +41,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(kNumberTypeBool);
return kBool;
}
} // namespace



+ 1
- 2
mindspore/core/ops/less_equal.cc View File

@@ -41,8 +41,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 1
- 2
mindspore/core/ops/local_response_normalization.cc View File

@@ -43,8 +43,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 3
- 4
mindspore/core/ops/log.cc View File

@@ -29,10 +29,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
const std::set<TypePtr> valid_types = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("infer type", input_args[0]->BuildType(), valid_types, prim->name());
return infer_type;
const std::set<TypePtr> valid_types = {kTensorType};
return CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types,
prim->name());
}
} // namespace



+ 2
- 6
mindspore/core/ops/logical_and.cc View File

@@ -39,14 +39,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
const std::set<TypeId> valid_types = {kNumberTypeBool};
const std::set<TypePtr> valid_types = {kBool};
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (infer_type == kNumberTypeBool) {
return TypeIdToType(infer_type);
}
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeBool));
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace



+ 2
- 6
mindspore/core/ops/logical_not.cc View File

@@ -37,12 +37,8 @@ TypePtr LogicalNotInferType(const PrimitivePtr &prim, const std::vector<Abstract
MS_EXCEPTION_IF_NULL(LogicalNot_prim);
auto op_name = LogicalNot_prim->name();
auto infer_dtype = input_args[0]->BuildType();
std::set<TypeId> local_bool = {kNumberTypeBool};
CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name);
auto tensor_type = infer_dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
return element;
std::set<TypePtr> local_bool = {kBool};
return CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, local_bool, op_name);
}
} // namespace
AbstractBasePtr LogicalNotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 2
- 6
mindspore/core/ops/logical_or.cc View File

@@ -40,14 +40,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
const std::set<TypeId> valid_types = {kNumberTypeBool};
const std::set<TypePtr> valid_types = {kBool};
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (infer_type == kNumberTypeBool) {
return TypeIdToType(infer_type);
}
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeBool));
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace



+ 2
- 3
mindspore/core/ops/lrn.cc View File

@@ -86,14 +86,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace



+ 1
- 2
mindspore/core/ops/lsh_projection.cc View File

@@ -61,8 +61,7 @@ AbstractBasePtr LshProjectionInfer(const abstract::AnalysisEnginePtr &, const Pr
out_shape.push_back(input0[0] * input0[1]);
break;
}
TypePtr infer_type = TypeIdToType(kNumberTypeInt32);
return std::make_shared<abstract::AbstractTensor>(infer_type, out_shape);
return std::make_shared<abstract::AbstractTensor>(kInt32, out_shape);
}
REGISTER_PRIMITIVE_C(kNameLshProjection, LshProjection);
} // namespace ops


+ 2
- 7
mindspore/core/ops/mat_mul.cc View File

@@ -55,16 +55,11 @@ TypePtr MatMulInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64,
kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("w", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (infer_type == kNumberTypeInt8) {
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
}
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace



+ 2
- 10
mindspore/core/ops/matrix_diag.cc View File

@@ -59,19 +59,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeUInt8, kNumberTypeFloat16,
kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kInt8, kInt32, kUInt8, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("assist", input_args[1]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
auto type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(type);
auto tensor_type = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();
MS_EXCEPTION_IF_NULL(data_type);
return data_type;
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace



+ 1
- 2
mindspore/core/ops/maximum.cc View File

@@ -38,8 +38,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 5
- 8
mindspore/core/ops/merge.cc View File

@@ -38,16 +38,13 @@ AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
for (int64_t i = 0; i != (int64_t)inputs_type.size(); i++) {
args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]});
}
std::set<TypeId> template_type = {kNumberTypeBool};
for (auto item : common_valid_types) {
template_type.insert(item);
}
CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name);
std::set<TypePtr> template_type = common_valid_types;
template_type.emplace(kBool);
auto infered_type = CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name);
std::vector<int64_t> in_shape0 = inputs_shape[0]->cast<abstract::ShapePtr>()->shape();

auto output1 =
std::make_shared<abstract::AbstractTensor>(inputs_type[0]->cast<TensorTypePtr>()->element(), in_shape0);
auto output2 = std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), std::vector<int64_t>{1});
auto output1 = std::make_shared<abstract::AbstractTensor>(infered_type, in_shape0);
auto output2 = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>{1});

AbstractBasePtrList output = {output1, output2};
return std::make_shared<abstract::AbstractTuple>(output);


+ 1
- 2
mindspore/core/ops/minimum.cc View File

@@ -42,8 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 2
- 1
mindspore/core/ops/neg.cc View File

@@ -31,7 +31,8 @@ AbstractBasePtr NegInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types,
prim_name);
return input_args[0]->Broaden();
}
REGISTER_PRIMITIVE_C(kNameNeg, Neg);


+ 1
- 1
mindspore/core/ops/non_max_suppression.cc View File

@@ -36,7 +36,7 @@ AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, cons
auto non_max_suppression_prim = primitive->cast<PrimNonMaxSuppressionPtr>();
MS_EXCEPTION_IF_NULL(non_max_suppression_prim);
MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime.";
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeInt32), std::vector<int64_t>{});
return std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>{});
}
REGISTER_PRIMITIVE_C(kNameNonMaxSuppression, NonMaxSuppression);
} // namespace ops


+ 5
- 11
mindspore/core/ops/one_hot.cc View File

@@ -53,17 +53,11 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
auto OneHot_prim = prim->cast<PrimOneHotPtr>();
MS_EXCEPTION_IF_NULL(OneHot_prim);
auto op_name = OneHot_prim->name();
CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kNumberTypeInt32}, op_name);
CheckAndConvertUtils::CheckTypeSame("depth", input_args[1]->BuildType(),
{kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64}, op_name);
auto value_type = input_args[2]->BuildType();
auto tensor_type = value_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
std::map<std::string, TypePtr> args = {{"on_value", value_type}, {"off_dtype", input_args[3]->BuildType()}};
CheckAndConvertUtils::CheckTensorTypeSame(args, {kNumberTypeFloat16, kNumberTypeFloat32}, op_name);
return element;
CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[0]->BuildType(), {kInt32}, op_name);
CheckAndConvertUtils::CheckTypeValid("depth", input_args[1]->BuildType(), {kInt8, kInt16, kInt32, kInt64}, op_name);
std::map<std::string, TypePtr> args = {{"on_value", input_args[2]->BuildType()},
{"off_dtype", input_args[3]->BuildType()}};
return CheckAndConvertUtils::CheckTensorTypeSame(args, {kFloat16, kFloat32}, op_name);
}
} // namespace
AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 1
- 6
mindspore/core/ops/ones_like.cc View File

@@ -37,13 +37,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}

TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
// const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64,
// kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64,
// kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64,
// kNumberTypeBool};
auto infer_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, common_valid_types, "OnesLike");
return infer_type;
return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, common_valid_types, "OnesLike");
}
} // namespace
AbstractBasePtr OnesLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 5
- 7
mindspore/core/ops/op_utils.h View File

@@ -230,14 +230,12 @@ constexpr auto kSpliceContext = "context";
constexpr auto kSpliceForwardIndexes = "forward_indexes";
constexpr auto kSpliceOutputDims = "output_dim";

const std::set<TypeId> common_valid_types = {
kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16,
kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};

const std::set<TypeId> all_types = {
kNumberTypeBool, kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64,
kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat,
kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex64,
const std::set<TypePtr> all_types = {
kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64,
};

abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args);


+ 2
- 4
mindspore/core/ops/pad.cc View File

@@ -49,10 +49,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypePtr> valid_types = {TypeIdToType(kObjectTypeTensorType)};
auto infer_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckSubClass("infer type", infer_type, valid_types, prim->name());
return infer_type;
const std::set<TypePtr> valid_types = {kTensorType};
return CheckAndConvertUtils::CheckSubClass("infer type", input_args[0]->BuildType(), valid_types, prim->name());
}
} // namespace



+ 1
- 2
mindspore/core/ops/pow.cc View File

@@ -37,8 +37,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 4
- 7
mindspore/core/ops/prelu.cc View File

@@ -46,13 +46,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), valid_types, prim->name());
CheckAndConvertUtils::CheckTensorTypeValid("weight", input_args[1]->BuildType(), valid_types, prim->name());
auto tensor_type = input_args[0]->BuildType()->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto input_x_type = tensor_type->element();
return input_x_type;
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<string, TypePtr> check_map = {{"input_x", input_args[0]->BuildType()},
{"weight", input_args[1]->BuildType()}};
return CheckAndConvertUtils::CheckTensorTypeSame(check_map, valid_types, prim->name());
}
} // namespace
AbstractBasePtr PReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 1
- 1
mindspore/core/ops/prior_box.cc View File

@@ -143,7 +143,7 @@ AbstractBasePtr PriorBoxInfer(const abstract::AnalysisEnginePtr &, const Primiti
auto input = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->BuildShape(), op_name);
int64_t h = input[0] * input[1] * num_priors_box * 4;
std::vector<int64_t> output_shape{1, h, 1, 2};
return std::make_shared<abstract::AbstractTensor>(TypeIdToType(kNumberTypeFloat32), output_shape);
return std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
}
REGISTER_PRIMITIVE_C(kNamePriorBox, PriorBox);
} // namespace ops


+ 1
- 2
mindspore/core/ops/range.cc View File

@@ -100,12 +100,11 @@ AbstractBasePtr RangeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
int64_t start = prim->get_start();
int64_t limit = prim->get_limit();
int64_t delta = prim->get_delta();
dtype = kNumberTypeInt32;
shape_size =
std::max(static_cast<int64_t>(std::ceil(LongToDouble(limit - start) / delta)), static_cast<int64_t>(0));
}
return std::make_shared<abstract::AbstractTensor>(
TypeIdToType(dtype), std::make_shared<abstract::Shape>(std::vector<int64_t>{shape_size}));
kInt32, std::make_shared<abstract::Shape>(std::vector<int64_t>{shape_size}));
}
REGISTER_PRIMITIVE_C(kNameRange, Range);
} // namespace ops


+ 2
- 2
mindspore/core/ops/rank.cc View File

@@ -25,8 +25,8 @@ TypePtr RankInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
MS_EXCEPTION_IF_NULL(Rank_prim);
auto op_name = Rank_prim->name();
auto infer_dtype = input_args[0]->BuildType();
CheckAndConvertUtils::CheckSubClass("x", infer_dtype, {TypeIdToType(kObjectTypeTensorType)}, op_name);
return TypeIdToType(kMetaTypeNone);
CheckAndConvertUtils::CheckTensorTypeValid("x", infer_dtype, {kTensorType}, op_name);
return kTypeNone;
}
} // namespace
AbstractBasePtr RankInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 1
- 2
mindspore/core/ops/real_div.cc View File

@@ -41,8 +41,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



+ 2
- 3
mindspore/core/ops/reciprocal.cc View File

@@ -39,9 +39,8 @@ AbstractBasePtr ReciprocalInfer(const abstract::AnalysisEnginePtr &, const Primi
auto in_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), prim_name);
// infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> valid_x_type = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name);
std::set<TypePtr> valid_x_type = {kTensorType};
auto x_type = CheckAndConvertUtils::CheckTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
return std::make_shared<abstract::AbstractTensor>(x_type, in_shape);
}
REGISTER_PRIMITIVE_C(kNameReciprocal, Reciprocal);


+ 2
- 4
mindspore/core/ops/reduce.cc View File

@@ -87,10 +87,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
std::map<std::string, TypePtr> types;
types.emplace("input_x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", input_args[0]->BuildType(), common_valid_types,
prim->name());
}
} // namespace



+ 3
- 4
mindspore/core/ops/relu6.cc View File

@@ -35,14 +35,13 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
AbstractBasePtr ReLU6Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 3
- 6
mindspore/core/ops/resize_bilinear.cc View File

@@ -63,12 +63,9 @@ AbstractBasePtr ResizeBilinearInfer(const abstract::AnalysisEnginePtr &, const P
out_shape.insert(out_shape.end(), size.begin(), size.end());

// Infer type
auto input_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
CheckAndConvertUtils::CheckTensorTypeValid("input_type", input_type, valid_types, prim_name);
auto out_type = TypeIdToType(kNumberTypeFloat32);

return std::make_shared<abstract::AbstractTensor>(out_type, out_shape);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_type", input_args[0]->BuildType(), valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(kFloat32, out_shape);
}
REGISTER_PRIMITIVE_C(kNameResizeBilinear, ResizeBilinear);
} // namespace ops


+ 6
- 7
mindspore/core/ops/reverse_sequence.cc View File

@@ -62,15 +62,14 @@ AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const
CheckAndConvertUtils::CheckInteger("seq_lengths vector size", seq_lengths[0], kEqual, input_shape[batch_dim],
prim_name);
// infer type
std::set<TypeId> tmp(common_valid_types);
tmp.insert(kNumberTypeBool);
const std::set<TypeId> valid_x_types(tmp);
const std::set<TypeId> valid_seq_types = {kNumberTypeInt32, kNumberTypeInt64};
std::set<TypePtr> valid_x_types(common_valid_types);
valid_x_types.emplace(kBool);
const std::set<TypePtr> valid_seq_types = {kInt32, kInt64};
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto seq_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
CheckAndConvertUtils::CheckTensorTypeValid("x_type", x_type, valid_x_types, prim_name);
CheckAndConvertUtils::CheckTensorTypeValid("seq_type", seq_type, valid_seq_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(x_type, input_shape);
auto infered_type = CheckAndConvertUtils::CheckTensorTypeValid("x_type", x_type, valid_x_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("seq_type", seq_type, valid_seq_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(infered_type, input_shape);
}
REGISTER_PRIMITIVE_C(kNameReverseSequence, ReverseSequence);
} // namespace ops


+ 3
- 15
mindspore/core/ops/reverse_v2.cc View File

@@ -28,11 +28,6 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(reverseV2_prim);
auto prim_name = reverseV2_prim->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
// auto axis = reverseV2_prim->get_axis();
// int dim = x_shape.size();
// for (auto &axis_value : axis) {
// CheckAndConvertUtils::CheckInRange("axis value", axis_value, kIncludeLeft, {-dim, dim}, prim_name);
// }
return std::make_shared<abstract::Shape>(x_shape);
}

@@ -40,17 +35,10 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64,
kNumberTypeUInt8, kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64,
kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool};
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kBool};
auto infer_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, prim->name());
MS_EXCEPTION_IF_NULL(infer_type);
auto tensor_type = infer_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();
MS_EXCEPTION_IF_NULL(data_type);
return data_type;
return CheckAndConvertUtils::CheckTensorTypeValid("x type", infer_type, valid_types, prim->name());
}
} // namespace



+ 1
- 1
mindspore/core/ops/rfft.cc View File

@@ -39,7 +39,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return TypeIdToType(kNumberTypeComplex64);
return kComplex64;
}
} // namespace



+ 1
- 7
mindspore/core/ops/round.cc View File

@@ -29,13 +29,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = input_args[0]->BuildType();
CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, common_valid_types, prim->name());
MS_EXCEPTION_IF_NULL(infer_type);
auto tensor_type = infer_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto data_type = tensor_type->element();
MS_EXCEPTION_IF_NULL(data_type);
return data_type;
return CheckAndConvertUtils::CheckTensorTypeValid("x", infer_type, common_valid_types, prim->name());
}
} // namespace



+ 2
- 5
mindspore/core/ops/rsqrt.cc View File

@@ -38,13 +38,10 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}

TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), common_valid_types, prim->name());
}
} // namespace



+ 3
- 3
mindspore/core/ops/scatter_nd.cc View File

@@ -42,11 +42,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> indices_valid_types = {kNumberTypeInt32, kNumberTypeInt64};
const std::set<TypePtr> update_valid_types = {TypeIdToType(kObjectTypeTensorType)};
const std::set<TypePtr> indices_valid_types = {kInt32, kInt64};
const std::set<TypePtr> update_valid_types = {kTensorType};
auto indices_type = input_args[0]->BuildType();
auto update_type = input_args[1]->BuildType();
CheckAndConvertUtils::CheckSubClass("update type", update_type, update_valid_types, prim->name());
CheckAndConvertUtils::CheckTypeValid("update type", update_type, update_valid_types, prim->name());
CheckAndConvertUtils::CheckTensorTypeValid("indices type", indices_type, indices_valid_types, prim->name());
return input_args[1]->BuildType();
}


+ 3
- 7
mindspore/core/ops/sigmoid_cross_entropy_with_logits.cc View File

@@ -41,16 +41,12 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
CheckAndConvertUtils::Check("x_shape", x_shape, kEqual, "y_shape", y_shape, prim_name, TypeError);

// Infer type
auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
const std::set<TypeId> valid_types = {
kNumberTypeBool, kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16,
kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8,
kNumberTypeUInt16, kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat,
kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex64};
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
std::map<std::string, TypePtr> args;
args.emplace("x_type", input_args[0]->BuildType());
args.emplace("y_type", input_args[1]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
auto x_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);

return std::make_shared<abstract::AbstractTensor>(x_type, x_shape);
}


+ 2
- 3
mindspore/core/ops/sin.cc View File

@@ -39,9 +39,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
CheckAndConvertUtils::CheckTensorTypeValid("x type", input_args[0]->BuildType(), common_valid_types, prim->name());
return infer_type;
return CheckAndConvertUtils::CheckTensorTypeValid("x type", input_args[0]->BuildType(), common_valid_types,
prim->name());
}
} // namespace



+ 2
- 3
mindspore/core/ops/smooth_l1_loss.cc View File

@@ -47,12 +47,11 @@ AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const Pri
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);

// Infer type
auto prediction_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> args;
args.emplace("scale", input_args[0]->BuildType());
args.emplace("bias", input_args[1]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
auto prediction_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);

return std::make_shared<abstract::AbstractTensor>(prediction_type, prediction);
}


+ 2
- 5
mindspore/core/ops/softmax.cc View File

@@ -62,11 +62,8 @@ TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "nullptr";
}
std::map<std::string, TypePtr> types;
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, prim->name());
}

AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 2
- 3
mindspore/core/ops/softmax_cross_entropy_with_logits.cc View File

@@ -46,12 +46,11 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
auto dlogits_shape = logits_shape;

// Infer type
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> args;
args.emplace("logits_type", input_args[0]->BuildType());
args.emplace("labels_type", input_args[1]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
auto logits_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto logits_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);

auto output0 = std::make_shared<abstract::AbstractTensor>(logits_type, loss_shape);
auto output1 = std::make_shared<abstract::AbstractTensor>(logits_type, dlogits_shape);


+ 1
- 2
mindspore/core/ops/space_to_batch.cc View File

@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
}
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace
void SpaceToBatch::set_paddings(const std::vector<std::vector<int64_t>> &paddings) {


+ 1
- 2
mindspore/core/ops/space_to_batch_nd.cc View File

@@ -56,8 +56,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
return infer_type;
return input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
}
} // namespace



+ 0
- 4
mindspore/core/ops/sparse_to_dense.cc View File

@@ -38,11 +38,7 @@ AbstractBasePtr SparseToDenseInfer(const abstract::AnalysisEnginePtr &, const Pr
auto dense_shape =
CheckAndConvertUtils::ConvertShapePtrToShape("dense_shape", input_args[3]->BuildShape(), prim_name);
// infer type
auto indices_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
auto values_type = input_args[1]->BuildType()->cast<TensorTypePtr>()->element();
std::set<TypePtr> valid_type = {TypeIdToType(kObjectTypeTensorType)};
CheckAndConvertUtils::CheckSubClass("indices_type", indices_type, valid_type, prim_name);
CheckAndConvertUtils::CheckSubClass("values_type", values_type, valid_type, prim_name);
return std::make_shared<abstract::AbstractTensor>(values_type, dense_shape);
}
REGISTER_PRIMITIVE_C(kNameSparseToDense, SparseToDense);


+ 2
- 3
mindspore/core/ops/squared_difference.cc View File

@@ -37,12 +37,11 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
const std::set<TypeId> valid_types = {kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<TypePtr> valid_types = {kInt32, kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace



+ 1
- 2
mindspore/core/ops/sub.cc View File

@@ -42,8 +42,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
types.emplace("x", input_args[0]->BuildType());
types.emplace("y", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return TypeIdToType(infer_type);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
}
} // namespace



Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save