Browse Source

codedex_clean 0526

tags/v1.3.0
dingpeifei 4 years ago
parent
commit
1f68cf175e
18 changed files with 46 additions and 50 deletions
  1. +1
    -1
      mindspore/core/ops/apply_momentum.cc
  2. +1
    -1
      mindspore/core/ops/atan.cc
  3. +5
    -6
      mindspore/core/ops/avg_pool.cc
  4. +5
    -6
      mindspore/core/ops/batch_to_space_nd.cc
  5. +6
    -6
      mindspore/core/ops/conv2d_transpose.cc
  6. +2
    -2
      mindspore/core/ops/depth_to_space.cc
  7. +1
    -1
      mindspore/core/ops/fill.cc
  8. +2
    -2
      mindspore/core/ops/gather_nd.cc
  9. +1
    -1
      mindspore/core/ops/grad/conv2d_backprop_input.cc
  10. +2
    -2
      mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc
  11. +4
    -5
      mindspore/core/ops/mfcc.cc
  12. +1
    -1
      mindspore/core/ops/non_max_suppression.cc
  13. +1
    -1
      mindspore/core/ops/roi_pooling.cc
  14. +2
    -3
      mindspore/core/ops/scatter_nd.cc
  15. +2
    -3
      mindspore/core/ops/skip_gram.cc
  16. +2
    -2
      mindspore/core/ops/softmax_cross_entropy_with_logits.cc
  17. +3
    -3
      mindspore/core/ops/stack.cc
  18. +5
    -4
      mindspore/core/ops/unsorted_segment_sum.cc

+ 1
- 1
mindspore/core/ops/apply_momentum.cc View File

@@ -57,7 +57,7 @@ AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const Pr
                                    const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("apply_momentum_infer", input_args.size(), kEqual, 5, prim_name);
+  CheckAndConvertUtils::CheckInteger("apply_momentum_infer", SizeToLong(input_args.size()), kEqual, 5, prim_name);
 
   // Infer shape
   auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];


+ 1
- 1
mindspore/core/ops/atan.cc View File

@@ -24,7 +24,7 @@ AbstractBasePtr AtanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
                           const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("Atan_infer", input_args.size(), kEqual, 1, prim_name);
+  CheckAndConvertUtils::CheckInteger("Atan_infer", int64_t(input_args.size()), kEqual, 1, prim_name);
 
   // Infer Shape
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];


+ 5
- 6
mindspore/core/ops/avg_pool.cc View File

@@ -87,7 +87,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   if (format == NHWC) {
     in_shape = {in_shape[0], in_shape[3], in_shape[1], in_shape[2]};
   }
-  CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
+  CheckAndConvertUtils::CheckInteger("x_rank", SizeToLong(in_shape.size()), kEqual, 4, op_name);
   auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
   auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
   auto batch = in_shape[0];
@@ -112,14 +112,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   if (format == NHWC) {
     out_shape = {batch, out_h, out_w, channel};
   }
-  if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
+  if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t arg) { return arg <= 0; })) {
     MS_LOG(EXCEPTION) << "Kernel size is not valid.";
   }
   return std::make_shared<abstract::Shape>(out_shape);
 }
 
-TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
+TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
+  if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
   return input_args[0]->BuildType();
@@ -128,8 +128,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 
 AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
 }
 REGISTER_PRIMITIVE_C(kNameAvgPool, AvgPool);
 }  // namespace ops


+ 5
- 6
mindspore/core/ops/batch_to_space_nd.cc View File

@@ -30,7 +30,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  CheckAndConvertUtils::CheckInteger("input_x rank", x_shape.size(), kEqual, 4, prim_name);
+  CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
   auto out_shape = x_shape;
   int64_t block_shape_prod = 1;
   int64_t offset = 2;
@@ -52,7 +52,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   return std::make_shared<abstract::Shape>(out_shape);
 }
 
-TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -62,7 +62,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 }  // namespace
 
 void BatchToSpaceND::set_crops(std::vector<std::vector<int64_t>> crops) {
-  CheckAndConvertUtils::CheckInteger(kCrops, crops.size(), kEqual, 2, this->name());
+  CheckAndConvertUtils::CheckInteger(kCrops, SizeToLong(crops.size()), kEqual, 2, this->name());
   int64_t h = crops.size();
   int64_t w = crops[0].size();
   std::vector<int64_t> temp_w = {2, 2};
@@ -80,7 +80,7 @@ std::vector<std::vector<int64_t>> BatchToSpaceND::get_crops() const {
   return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
 }
 void BatchToSpaceND::set_block_shape(std::vector<int64_t> block_shape) {
-  CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape.size(), kEqual, 2, this->name());
+  CheckAndConvertUtils::CheckInteger(kBlockShape, SizeToLong(block_shape.size()), kEqual, 2, this->name());
   for (int64_t i = 0; i < (int64_t)block_shape.size(); i++) {
     CheckAndConvertUtils::CheckInteger(kBlockShape, block_shape[i], kGreaterEqual, 1, this->name());
   }
@@ -98,8 +98,7 @@ void BatchToSpaceND::Init(std::vector<int64_t> block_shape, std::vector<std::vec
 }
 AbstractBasePtr BatchToSpaceNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
 }
 REGISTER_PRIMITIVE_C(kNameBatchToSpaceND, BatchToSpaceND);
 }  // namespace ops


+ 6
- 6
mindspore/core/ops/conv2d_transpose.cc View File

@@ -33,7 +33,7 @@ abstract::ShapePtr Conv2dTransposeInferShape(const PrimitivePtr &primitive,
 }
 
 TypePtr Conv2dTransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", input_args.size(), kEqual, 3, prim->name());
+  CheckAndConvertUtils::CheckInteger("conv2d_transpose_infer", SizeToLong(input_args.size()), kEqual, 3, prim->name());
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -72,7 +72,7 @@ void Conv2dTranspose::set_out_channel(int64_t out_channel) {
 }
 
 void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
-  CheckAndConvertUtils::CheckInteger(kKernelSize, kernel_size.size(), kEqual, 2, name());
+  CheckAndConvertUtils::CheckInteger(kKernelSize, SizeToLong(kernel_size.size()), kEqual, 2, name());
   for (int64_t item : kernel_size) {
     CheckAndConvertUtils::CheckInteger(kKernelSize, item, kGreaterEqual, 1, name());
   }
@@ -80,7 +80,7 @@ void Conv2dTranspose::set_kernel_size(const std::vector<int64_t> &kernel_size) {
 }
 
 void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
-  CheckAndConvertUtils::CheckInteger(kStride, stride.size(), kEqual, 2, name());
+  CheckAndConvertUtils::CheckInteger(kStride, SizeToLong(stride.size()), kEqual, 2, name());
   for (int64_t item : stride) {
     CheckAndConvertUtils::CheckInteger(kStride, item, kGreaterEqual, 1, name());
   }
@@ -88,7 +88,7 @@ void Conv2dTranspose::set_stride(const std::vector<int64_t> &stride) {
 }
 
 void Conv2dTranspose::set_dilation(const std::vector<int64_t> &dilation) {
-  CheckAndConvertUtils::CheckInteger(kDilation, dilation.size(), kGreaterEqual, 2, name());
+  CheckAndConvertUtils::CheckInteger(kDilation, SizeToLong(dilation.size()), kGreaterEqual, 2, name());
   AddAttr(kDilation, MakeValue(dilation));
 }
 
@@ -106,7 +106,7 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) {
 }
 
 void Conv2dTranspose::set_pad(const std::vector<int64_t> &pad) {
-  CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
+  CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
   AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
 }
 
@@ -124,7 +124,7 @@ void Conv2dTranspose::set_format(const Format &format) {
 }
 
 void Conv2dTranspose::set_pad_list(const std::vector<int64_t> &pad_list) {
-  CheckAndConvertUtils::CheckInteger(kPadList, pad_list.size(), kEqual, 4, name());
+  CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, 4, name());
   this->AddAttr(kPadList, MakeValue(pad_list));
 }



+ 2
- 2
mindspore/core/ops/depth_to_space.cc View File

@@ -47,7 +47,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
                                   const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 1, prim_name);
+  CheckAndConvertUtils::CheckInteger("input number", int64_t(input_args.size()), kEqual, 1, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -59,7 +59,7 @@ AbstractBasePtr DepthToSpaceInfer(const abstract::AnalysisEnginePtr &, const Pri
   if (format == NHWC) {
     x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
   }
-  CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
+  CheckAndConvertUtils::CheckInteger("x rank", SizeToLong(x_shape.size()), kEqual, 4, prim_name);
   int64_t block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
   CheckAndConvertUtils::CheckInteger("x_shape[1] % (block_size*block_size)", x_shape[1] % (block_size * block_size),
                                      kEqual, 0, prim_name);


+ 1
- 1
mindspore/core/ops/fill.cc View File

@@ -26,7 +26,7 @@ AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
                           const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
+  CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 3, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }


+ 2
- 2
mindspore/core/ops/gather_nd.cc View File

@@ -28,7 +28,7 @@ namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, prim_name);
+  CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 2, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -50,7 +50,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
 
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64};
-  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
+  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &arg) { return arg == nullptr; })) {
     MS_LOG(EXCEPTION) << "nullptr";
   }
   std::map<std::string, TypePtr> types;


+ 1
- 1
mindspore/core/ops/grad/conv2d_backprop_input.cc View File

@@ -147,7 +147,7 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) {
 }
 
 void Conv2DBackpropInput::set_pad(const std::vector<int64_t> &pad) {
-  CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
+  CheckAndConvertUtils::CheckInteger("pad_size", SizeToLong(pad.size()), kEqual, 4, name());
   AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
 }



+ 2
- 2
mindspore/core/ops/grad/sigmoid_cross_entropy_with_logits_grad.cc View File

@@ -31,8 +31,8 @@ AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisE
                                                        const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", input_args.size(), kEqual, 3,
-                                     prim_name);
+  CheckAndConvertUtils::CheckInteger("sigmoid_cross_entropy_with_logits_grad_infer", SizeToLong(input_args.size()),
+                                     kEqual, 3, prim_name);
 
   // Infer Shape
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];


+ 4
- 5
mindspore/core/ops/mfcc.cc View File

@@ -27,14 +27,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   auto prim_name = primitive->name();
   auto first_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
   auto second_input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
-  CheckAndConvertUtils::CheckInteger("first input rank", first_input_shape.size(), kEqual, 3, prim_name);
-  CheckAndConvertUtils::CheckInteger("second input rank", second_input_shape.size(), kEqual, 1, prim_name);
+  CheckAndConvertUtils::CheckInteger("first input rank", SizeToLong(first_input_shape.size()), kEqual, 3, prim_name);
+  CheckAndConvertUtils::CheckInteger("second input rank", SizeToLong(second_input_shape.size()), kEqual, 1, prim_name);
   std::vector<int64_t> out_shape = {first_input_shape[0], first_input_shape[1],
                                     GetValue<int64_t>(primitive->GetAttr(kDctCoeffNum))};
   return std::make_shared<abstract::Shape>(out_shape);
 }
 
-TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -84,8 +84,7 @@ int64_t Mfcc::get_dct_coeff_num() const { return GetValue<int64_t>(GetAttr(kDctC
 
 AbstractBasePtr MfccInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
 }
 REGISTER_PRIMITIVE_C(kNameMfcc, Mfcc);
 }  // namespace ops


+ 1
- 1
mindspore/core/ops/non_max_suppression.cc View File

@@ -29,7 +29,7 @@ int64_t NonMaxSuppression::get_center_point_box() const {
 }
 void NonMaxSuppression::Init(const int64_t center_point_box) { this->set_center_point_box(center_point_box); }
 
-AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr NonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &,
                                        const std::vector<AbstractBasePtr> &input_args) {
   MS_LOG(INFO) << "NonMaxSuppression infer shape in runtime.";
   return std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>{});


+ 1
- 1
mindspore/core/ops/roi_pooling.cc View File

@@ -52,7 +52,7 @@ AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const Primi
                                 const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("roi_pooling_infer", input_args.size(), kEqual, 2, prim_name);
+  CheckAndConvertUtils::CheckInteger("roi_pooling_infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
   MS_EXCEPTION_IF_NULL(input_args[0]);
   MS_EXCEPTION_IF_NULL(input_args[1]);



+ 2
- 3
mindspore/core/ops/scatter_nd.cc View File

@@ -23,7 +23,7 @@
 namespace mindspore {
 namespace ops {
 namespace {
-abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+abstract::ShapePtr InferShape(const std::vector<AbstractBasePtr> &input_args) {
   auto shape_value = input_args[2]->BuildValue();
   auto shape_value_element = GetValue<std::vector<int64_t>>(shape_value);
   for (const auto &shape : shape_value_element) {
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 
 AbstractBasePtr ScatterNdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args), InferShape(input_args)->shape());
 }
 REGISTER_PRIMITIVE_C(kNameScatterNd, ScatterNd);
 }  // namespace ops


+ 2
- 3
mindspore/core/ops/skip_gram.cc View File

@@ -34,7 +34,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   return std::make_shared<abstract::Shape>(in_shape);
 }
 
-TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+TypePtr InferType(const std::vector<AbstractBasePtr> &input_args) {
   auto infer_type = input_args[0]->BuildType();
   return infer_type;
 }
@@ -65,8 +65,7 @@ void SkipGram::Init(const bool include_all_grams, const int64_t max_skip_size, c
 
 AbstractBasePtr SkipGramInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return std::make_shared<abstract::AbstractTensor>(InferType(input_args), InferShape(primitive, input_args)->shape());
 }
 REGISTER_PRIMITIVE_C(kNameSkipGram, SkipGram);
 }  // namespace ops


+ 2
- 2
mindspore/core/ops/softmax_cross_entropy_with_logits.cc View File

@@ -30,8 +30,8 @@ AbstractBasePtr SoftmaxCrossEntropyWithLogitsInfer(const abstract::AnalysisEngin
                                                    const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", input_args.size(), kEqual, 2,
-                                     prim_name);
+  CheckAndConvertUtils::CheckInteger("softmax_cross_entropy_with_logics_infer", SizeToLong(input_args.size()), kEqual,
+                                     2, prim_name);
 
   // Infer shape
   auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];


+ 3
- 3
mindspore/core/ops/stack.cc View File

@@ -29,12 +29,12 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
     MS_LOG(ERROR) << "Invalid input size " << input_args.size();
   }
   auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  for (int64_t i = 1; i < (int64_t)input_args.size(); ++i) {
+  for (int64_t i = 1; i < SizeToLong(input_args.size()); ++i) {
     auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape];
     if (input_shape_tmp.size() != input_shape.size()) {
       MS_LOG(ERROR) << "All input shape size should be the same!";
     }
-    for (int64_t j = 0; j < (int64_t)input_shape.size(); ++j) {
+    for (int64_t j = 0; j < SizeToLong(input_shape.size()); ++j) {
       if (input_shape_tmp.at(j) != input_shape.at(j)) {
         MS_LOG(ERROR) << "All input shape should be the same!";
       }
@@ -44,7 +44,7 @@ abstract::AbstractBasePtr StackInfer(const PrimitivePtr &primitive, const std::v
   infer_shape.insert(infer_shape.begin() + GetValue<int64_t>(primitive->GetAttr(kAxis)), input_args.size());
 
   auto infer_type0 = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-  for (int64_t i = 1; i < (int64_t)input_args.size(); i++) {
+  for (int64_t i = 1; i < SizeToLong(input_args.size()); i++) {
     if (input_args[i]->BuildType()->cast<TensorTypePtr>()->element() == infer_type0) {
       MS_LOG(ERROR) << "All input should have the same data type!input[" << i
                     << "] data type = " << input_args[i]->BuildType()->cast<TensorTypePtr>()->element();


+ 5
- 4
mindspore/core/ops/unsorted_segment_sum.cc View File

@@ -34,12 +34,13 @@ AbstractBasePtr UnsortedSegmentSumInfer(const abstract::AnalysisEnginePtr &, con
   auto x_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
   // Infer shape
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  CheckAndConvertUtils::CheckInteger("x_shape", x_shape.size(), kGreaterThan, 0, prim_name);
+  CheckAndConvertUtils::CheckInteger("x_shape", SizeToLong(x_shape.size()), kGreaterThan, 0, prim_name);
   auto shp = x_shape;
   auto segment_ids_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
-  CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kGreaterThan, 0, prim_name);
-  CheckAndConvertUtils::Check("input_x", x_shape.size(), kGreaterEqual, "segment_ids_shape", segment_ids_shape.size(),
-                              prim_name);
+  CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kGreaterThan, 0,
+                                     prim_name);
+  CheckAndConvertUtils::Check("input_x", int64_t(x_shape.size()), kGreaterEqual, "segment_ids_shape",
+                              int64_t(segment_ids_shape.size()), prim_name);
 
   if ((x_shape.end() != find(x_shape.begin(), x_shape.end(), -1)) &&
       (segment_ids_shape.end() != find(segment_ids_shape.begin(), segment_ids_shape.end(), -1))) {


Loading…
Cancel
Save