| @@ -57,7 +57,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("apply_momentum_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name); | |||
| // Infer shape | |||
| auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| @@ -24,7 +24,7 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("Atan_infer", int64_t(input_args.size()), kEqual, 1, prim_name); | |||
| // Infer Shape | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| @@ -87,7 +87,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| if (format == NHWC) { | |||
| in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]}; | |||
| } | |||
| CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name); | |||
| CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name); | |||
| auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)); | |||
| auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode))); | |||
| auto batch = in_shape[0]; | |||
| @@ -112,14 +112,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| if (format == NHWC) { | |||
| out_shape = {batch, out_h, out_w, channel}; | |||
| } | |||
| if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) { | |||
| if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t arg) { return arg <= 0; })) { | |||
| MS_LOG(EXCEPTION) << "Kernel size is not valid."; | |||
| } | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) { | |||
| TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) { | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| return input_args[0]->BuildType(); | |||
| @@ -128,8 +128,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> & | |||
| AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool); | |||
| } // namespace ops | |||
| @@ -30,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name); | |||
| auto out_shape = x_shape; | |||
| int64_t block_shape_prod = 1; | |||
| int64_t offset = 2; | |||
| @@ -52,7 +52,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) { | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| @@ -62,7 +62,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> & | |||
| } // namespace | |||
| void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) { | |||
| CheckAndConvertUtils::CheckInteger(kCrops, crops.size(), kEqual, 2, this->name()); | |||
| CheckAndConvertUtils::CheckInteger(kCrops, SizeToLong(crops.size()), kEqual, 2, this->name()); | |||
| int64_t h = crops.size(); | |||
| int64_t w = crops[0].size(); | |||
| std::vector<int64_t> temp_w = {2, 2}; | |||
| @@ -80,7 +80,7 @@ std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const { | |||
| return GetValue<std::vector<std::vector<int64_t>>>(value_ptr); | |||
| } | |||
| void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) { | |||
| CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape.size(), kEqual, 2, this->name()); | |||
| CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name()); | |||
| for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) { | |||
| CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name()); | |||
| } | |||
| @@ -98,8 +98,7 @@ void BatchToSpaceND::Init(std::vector<int64_t> block_shape, std::vector<std::vec | |||
| } | |||
| AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameBatchToSpaceND, BatchToSpaceND); | |||
| } // namespace ops | |||
| @@ -33,7 +33,7 @@ abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive, | |||
| } | |||
| TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", input_args.size(), kEqual, 3, prim->name()); | |||
| CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", SizeToLong(input_args.size()), kEqual, 3, prim->name()); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| @@ -72,7 +72,7 @@ void Conv2dTranspose::set_out_channel(int64_t out_channel) { | |||
| } | |||
| void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||
| CheckAndConvertUtils::CheckInteger(kKernelSize, kernel_size.size(), kEqual, 2, name()); | |||
| CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name()); | |||
| for (int64_t item : kernel_size) { | |||
| CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name()); | |||
| } | |||
| @@ -80,7 +80,7 @@ void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) { | |||
| } | |||
| void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) { | |||
| CheckAndConvertUtils::CheckInteger(kStride, stride.size(), kEqual, 2, name()); | |||
| CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name()); | |||
| for (int64_t item : stride) { | |||
| CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name()); | |||
| } | |||
| @@ -88,7 +88,7 @@ void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) { | |||
| } | |||
| void Conv2dTranspose::set_dilation(const std::vector<int64_t> &dilation) { | |||
| CheckAndConvertUtils::CheckInteger(kDilation, dilation.size(), kGreaterEqual, 2, name()); | |||
| CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name()); | |||
| AddAttr(kDilation, MakeValue(dilation)); | |||
| } | |||
| @@ -106,7 +106,7 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) { | |||
| } | |||
| void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) { | |||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | |||
| CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name()); | |||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name()))); | |||
| } | |||
| @@ -124,7 +124,7 @@ void Conv2dTranspose::set_format(const Format &format) { | |||
| } | |||
| void Conv2dTranspose::set_pad_list(const std::vector<int64_t> &pad_list) { | |||
| CheckAndConvertUtils::CheckInteger(kPadList, pad_list.size(), kEqual, 4, name()); | |||
| CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name()); | |||
| this->AddAttr(kPadList, MakeValue(pad_list)); | |||
| } | |||
| @@ -47,7 +47,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| @@ -59,7 +59,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri | |||
| if (format == NHWC) { | |||
| x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]}; | |||
| } | |||
| CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name); | |||
| int64_t block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize)); | |||
| CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)", x_shape[1] % (block_size * block_size), | |||
| kEqual, 0, prim_name); | |||
| @@ -26,7 +26,7 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| @@ -28,7 +28,7 @@ namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| @@ -50,7 +50,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64}; | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { | |||
| if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &arg) { return arg == nullptr; })) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| } | |||
| std::map<std::string, TypePtr> types; | |||
| @@ -147,7 +147,7 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) { | |||
| } | |||
| void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) { | |||
| CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); | |||
| CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name()); | |||
| AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name()))); | |||
| } | |||
| @@ -31,8 +31,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", input_args.size(), kEqual, 3, | |||
| prim_name); | |||
| CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", SizeToLong(input_args.size()), | |||
| kEqual, 3, prim_name); | |||
| // Infer Shape | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| @@ -27,14 +27,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| auto prim_name = primitive->name(); | |||
| auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||
| CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("first input rank", SizeToLong(first_input_shape.size()), kEqual, 3, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("second input rank", SizeToLong(second_input_shape.size()), kEqual, 1, prim_name); | |||
| std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1], | |||
| GetValue<int64_t>(primitive->GetAttr(kDctCoeffNum))}; | |||
| return std::make_shared<abstract::Shape>(out_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) { | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| @@ -84,8 +84,7 @@ int64_t Mfcc::get_dct_coeff_num() const { return GetValue<int64_t>(GetAttr(kDctC | |||
| AbstractBasePtr MfccInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameMfcc, Mfcc); | |||
| } // namespace ops | |||
| @@ -29,7 +29,7 @@ int64_t NonMaxSuppression::get_center_point_box() const { | |||
| } | |||
| void NonMaxSuppression::Init(const int64_t center_point_box) { this->set_center_point_box(center_point_box); } | |||
| AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime."; | |||
| return std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>{}); | |||
| @@ -52,7 +52,7 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("roi_pooling_infer", input_args.size(), kEqual, 2, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("roi_pooling_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name); | |||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||
| MS_EXCEPTION_IF_NULL(input_args[1]); | |||
| @@ -23,7 +23,7 @@ | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) { | |||
| auto shape_value = input_args[2]->BuildValue(); | |||
| auto shape_value_element = GetValue<std::vector<int64_t>>(shape_value); | |||
| for (const auto &shape : shape_value_element) { | |||
| @@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> & | |||
| AbstractBasePtr ScatterNdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameScatterNd, ScatterNd); | |||
| } // namespace ops | |||
| @@ -34,7 +34,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A | |||
| return std::make_shared<abstract::Shape>(in_shape); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { | |||
| TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) { | |||
| auto infer_type = input_args[0]->BuildType(); | |||
| return infer_type; | |||
| } | |||
| @@ -65,8 +65,7 @@ void SkipGram::Init(const bool include_all_grams, const int64_t max_skip_size, c | |||
| AbstractBasePtr SkipGramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), | |||
| InferShape(primitive, input_args)->shape()); | |||
| return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape()); | |||
| } | |||
| REGISTER_PRIMITIVE_C(kNameSkipGram, SkipGram); | |||
| } // namespace ops | |||
| @@ -30,8 +30,8 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", input_args.size(), kEqual, 2, | |||
| prim_name); | |||
| CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", SizeToLong(input_args.size()), kEqual, | |||
| 2, prim_name); | |||
| // Infer shape | |||
| auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| @@ -29,12 +29,12 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v | |||
| MS_LOG(ERROR) << "Invalid input size " << input_args.size(); | |||
| } | |||
| auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) { | |||
| for (int64_t i = 1; i < SizeToLong(input_args.size()); ++i) { | |||
| auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape]; | |||
| if (input_shape_tmp.size() != input_shape.size()) { | |||
| MS_LOG(ERROR) << "All input shape size should be the same!"; | |||
| } | |||
| for (int64_t j = 0; j < (int64_t)input_shape.size(); ++j) { | |||
| for (int64_t j = 0; j < SizeToLong(input_shape.size()); ++j) { | |||
| if (input_shape_tmp.at(j) != input_shape.at(j)) { | |||
| MS_LOG(ERROR) << "All input shape should be the same!"; | |||
| } | |||
| @@ -44,7 +44,7 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v | |||
| infer_shape.insert(infer_shape.begin() + GetValue<int64_t>(primitive->GetAttr(kAxis)), input_args.size()); | |||
| auto infer_type0 = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| for (int64_t i = 1; i < (int64_t)input_args.size(); i++) { | |||
| for (int64_t i = 1; i < SizeToLong(input_args.size()); i++) { | |||
| if (input_args[i]->BuildType()->cast<TensorTypePtr>()->element() == infer_type0) { | |||
| MS_LOG(ERROR) << "All input should have the same data type!input[" << i | |||
| << "] data type = " << input_args[i]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| @@ -34,12 +34,13 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con | |||
| auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element(); | |||
| // Infer shape | |||
| auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterThan, 0, prim_name); | |||
| CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kGreaterThan, 0, prim_name); | |||
| auto shp = x_shape; | |||
| auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||
| CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kGreaterThan, 0, prim_name); | |||
| CheckAndConvertUtils::Check("input_x", x_shape.size(), kGreaterEqual, "segment_ids_shape", segment_ids_shape.size(), | |||
| prim_name); | |||
| CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kGreaterThan, 0, | |||
| prim_name); | |||
| CheckAndConvertUtils::Check("input_x", int64_t(x_shape.size()), kGreaterEqual, "segment_ids_shape", | |||
| int64_t(segment_ids_shape.size()), prim_name); | |||
| if ((x_shape.end() != find(x_shape.begin(), x_shape.end(), -1)) && | |||
| (segment_ids_shape.end() != find(segment_ids_shape.begin(), segment_ids_shape.end(), -1))) { | |||