diff --git a/mindspore/lite/nnacl/base/arithmetic_base.c b/mindspore/lite/nnacl/base/arithmetic_base.c index 49cda36afe..5a04ccc07d 100644 --- a/mindspore/lite/nnacl/base/arithmetic_base.c +++ b/mindspore/lite/nnacl/base/arithmetic_base.c @@ -20,8 +20,12 @@ void CalcMultiplesAndStrides(ArithmeticParameter *param) { NNACL_ASSERT(param->in_shape0_[i] != 0); NNACL_ASSERT(param->in_shape1_[i] != 0); for (size_t i = 0; i < param->ndim_; i++) { - param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i]; - param->multiples1_[i] = param->out_shape_[i] / param->in_shape1_[i]; + if (param->in_shape0_[i] != 0) { + param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i]; + } + if (param->in_shape1_[i] != 0) { + param->multiples1_[i] = param->out_shape_[i] / param->in_shape1_[i]; + } } // cal strides ComputeStrides(param->in_shape0_, param->in_strides0_, param->ndim_); diff --git a/mindspore/lite/nnacl/fp32/resize_fp32.c b/mindspore/lite/nnacl/fp32/resize_fp32.c index 2cb73340a7..1a4317a183 100644 --- a/mindspore/lite/nnacl/fp32/resize_fp32.c +++ b/mindspore/lite/nnacl/fp32/resize_fp32.c @@ -63,14 +63,13 @@ int PrepareCropAndResizeBilinear(const int *input_shape, const float *boxes, con int new_height = output_shape[1]; int new_width = output_shape[2]; - for (int i = 0; i < batch; i++) { - int b = box_idx[i]; + for (int b = 0; b < batch; b++) { const float *box = boxes + b * 4; - int start_h = box[0] * (in_h - 1); - int end_h = box[2] * (in_h - 1); - int start_w = box[1] * (in_w - 1); - int end_w = box[3] * (in_w - 1); - if (start_h >= end_h || start_w >= end_w || end_h >= in_h || end_w >= in_w) { + float start_h = box[0]; + float end_h = box[2]; + float start_w = box[1]; + float end_w = box[3]; + if (start_h > end_h || start_w > end_w || end_h > 1 || end_w > 1) { return NNACL_PARAM_INVALID; } diff --git a/mindspore/lite/nnacl/where.c b/mindspore/lite/nnacl/where.c index 666a9a79c7..dad83be8d8 100644 --- a/mindspore/lite/nnacl/where.c +++ b/mindspore/lite/nnacl/where.c @@ -14,14 +14,15 @@ * limitations under the License. */ #include "nnacl/where.h" +#include "nnacl/common_func.h" -void Where(bool *input, const float *input1, const float *input2, float *output, WhereParameter *where_param_, - int task_id) { - for (int i = task_id; i < where_param_->number_; i += where_param_->op_parameter_.thread_num_) { - if (input[where_param_->num_ > 1 ? i : 0] == true) { - output[i] = input1[where_param_->num1_ > 1 ? i : 0]; +void WhereWithTripleInputs(const bool *condition, const float *x, const float *y, float *output, + WhereParameter *where_param_, int task_id) { + for (int i = task_id; i < where_param_->max_num_; i += where_param_->op_parameter_.thread_num_) { + if (condition[where_param_->condition_num_ > 1 ? i : 0] == true) { + output[i] = x[where_param_->x_num_ > 1 ? i : 0]; } else { - output[i] = input2[where_param_->num2_ > 1 ? i : 0]; + output[i] = y[where_param_->y_num_ > 1 ? i : 0]; } } } diff --git a/mindspore/lite/nnacl/where.h b/mindspore/lite/nnacl/where.h index 68860bdbbe..91c0cfc6fa 100644 --- a/mindspore/lite/nnacl/where.h +++ b/mindspore/lite/nnacl/where.h @@ -23,18 +23,20 @@ typedef struct WhereParameter { OpParameter op_parameter_; // other parameter - int num_; - int num1_; - int num2_; - int number_; + int condition_num_; + int x_num_; + int y_num_; + int max_num_; + + int rank_; int thread_num_; } WhereParameter; #ifdef __cplusplus extern "C" { #endif -void Where(bool *input, const float *input1, const float *input2, float *output, WhereParameter *where_param_, - int task_id); +void WhereWithTripleInputs(const bool *condition, const float *x, const float *y, float *output, + WhereParameter *where_param_, int task_id); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index e36def21fd..fee23a5354 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -145,10 +145,22 @@ int LiteKernel::Run(const KernelCallBack &before, const KernelCallBack &after) { MS_LOG(WARNING) << "run kernel before_callback failed, name: " << this->name_; } } - auto ret = Run(); - if (RET_OK != ret) { - MS_LOG(ERROR) << "run kernel failed, name: " << this->name_; - return ret; + // Support ZeroShape + size_t zero_shape_num = 0; + for (auto tensor : this->out_tensors_) { + for (auto dim : tensor->shape()) { + if (dim == 0) { + zero_shape_num++; + continue; + } + } + } + if (zero_shape_num != this->out_tensors_.size()) { + auto ret = Run(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "run kernel failed, name: " << this->name_; + return ret; + } } if (after != nullptr) { if (!after(TensorVectorCast(this->in_tensors_), TensorVectorCast(this->out_tensors_), diff --git a/mindspore/lite/src/ops/batch_to_space.cc b/mindspore/lite/src/ops/batch_to_space.cc index 34e03a67b5..8dc4b2d229 100644 --- a/mindspore/lite/src/ops/batch_to_space.cc +++ b/mindspore/lite/src/ops/batch_to_space.cc @@ -101,8 +101,8 @@ int BatchToSpace::InferShape(std::vector inputs, std::vectorshape(); - if (input_shape.size() != kDimension_4d) { - MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; + if (input_shape.size() != kQuadrupleNum) { + MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum; return RET_PARAM_INVALID; } diff --git a/mindspore/lite/src/ops/crop_and_resize.cc b/mindspore/lite/src/ops/crop_and_resize.cc index 76badc0235..f9af64baef 100644 --- a/mindspore/lite/src/ops/crop_and_resize.cc +++ b/mindspore/lite/src/ops/crop_and_resize.cc @@ -93,8 +93,13 @@ int CropAndResize::InferShape(std::vector inputs_, std::vector output_shape; - auto boxes_tensor = inputs_[1]; - output_shape.push_back(boxes_tensor->shape()[0]); + if (inputs_[1]->data_c() != nullptr) { + auto boxes_tensor = inputs_[1]; + output_shape.push_back(boxes_tensor->shape()[0]); + } else { + output_shape.push_back(input->Batch()); + } + auto shape_tensor = inputs_[3]; auto data = reinterpret_cast(shape_tensor->data_c()); if (data == nullptr) { diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.cc b/mindspore/lite/src/ops/dedepthwise_conv2d.cc index c1f7022272..5badb82105 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.cc +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.cc @@ -117,7 +117,7 @@ Registry DeDepthwiseConv2DRegistry(schema::PrimitiveType_DeDepthwiseConv2D, DeDe #endif int DeDepthwiseConv2D::InferShape(std::vector inputs_, std::vector outputs_) { - if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { + if (inputs_.size() != kDoubleNum && inputs_.size() != kTripleNum) { MS_LOG(ERROR) << "inputs number is invalid"; return 1; } diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc index 194e30d4ac..1109e678a5 100644 --- a/mindspore/lite/src/ops/depth_to_space.cc +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -76,14 +76,14 @@ int DepthToSpace::InferShape(std::vector inputs, std::vectorshape(); - if (input_shape.size() != kDimension_4d) { - MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; + if (input_shape.size() != kQuadrupleNum) { + MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum; return RET_PARAM_INVALID; } int32_t block_size = GetBlockSize(); if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) { - MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size(" + MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be multiple of block_size(" << block_size << ") * block_size)!"; return RET_PARAM_INVALID; } diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index ad5bef8213..587a58242b 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -193,7 +193,7 @@ Registry DepthWiseConv2DRegistry(schema::PrimitiveType_DepthwiseConv2D, DepthWis #endif int DepthwiseConv2D::InferShape(std::vector inputs_, std::vector outputs_) { - if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { + if (inputs_.size() != kDoubleNum && inputs_.size() != kTripleNum) { MS_LOG(ERROR) << "inputs number is invalid"; return 1; } diff --git a/mindspore/lite/src/ops/full_connection.cc b/mindspore/lite/src/ops/full_connection.cc index f34ea6f660..7ec366c870 100644 --- a/mindspore/lite/src/ops/full_connection.cc +++ b/mindspore/lite/src/ops/full_connection.cc @@ -72,7 +72,7 @@ int FullConnection::InferShape(std::vector inputs_, std::vector< if (!infer_flag()) { return RET_INFER_INVALID; } - if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) { + if ((GetHasBias() && inputs_.size() != kTripleNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) { MS_LOG(ERROR) << "Input tensors num error"; return RET_INPUT_TENSOR_ERROR; } diff --git a/mindspore/lite/src/ops/layer_norm.cc b/mindspore/lite/src/ops/layer_norm.cc index a5b1c597c2..09d4fd5d02 100644 --- a/mindspore/lite/src/ops/layer_norm.cc +++ b/mindspore/lite/src/ops/layer_norm.cc @@ -105,7 +105,7 @@ Registry LayerNormRegistry(schema::PrimitiveType_LayerNorm, LayerNormCreator); #endif int LayerNorm::InferShape(std::vector inputs_, std::vector outputs_) { - if (outputs_.size() != kSingleNum || (inputs_.size() != kSingleNum && inputs_.size() != kMultiNum)) { + if (outputs_.size() != kSingleNum || (inputs_.size() != kSingleNum && inputs_.size() != kTripleNum)) { MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs_.size() << ",input size: " << inputs_.size(); return RET_PARAM_INVALID; } @@ -116,7 +116,7 @@ int LayerNorm::InferShape(std::vector inputs_, std::vectorset_format(input->format()); output->set_data_type(input->data_type()); - if (GetElementwiseAffine() && inputs_.size() != kMultiNum) { + if (GetElementwiseAffine() && inputs_.size() != kTripleNum) { MS_LOG(INFO) << "input tensor amount error"; return RET_INPUT_TENSOR_ERROR; } diff --git a/mindspore/lite/src/ops/lsh_projection.cc b/mindspore/lite/src/ops/lsh_projection.cc index 5ab54d9be0..893fde7870 100644 --- a/mindspore/lite/src/ops/lsh_projection.cc +++ b/mindspore/lite/src/ops/lsh_projection.cc @@ -49,7 +49,7 @@ Registry LshProjectionRegistry(schema::PrimitiveType_LshProjection, LshProjectio #endif int LshProjection::InferShape(std::vector inputs_, std::vector outputs_) { - if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { + if (inputs_.size() != kDoubleNum && inputs_.size() != kTripleNum) { MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given."; return RET_ERROR; } @@ -63,7 +63,7 @@ int LshProjection::InferShape(std::vector inputs_, std::vectorDimensionSize(1) <= 32); MS_ASSERT(inputs_.at(1)->shape().size() >= 1); - if (inputs_.size() == kMultiNum) { + if (inputs_.size() == kTripleNum) { MS_ASSERT(inputs_.at(2)->shape().size() == 1); MS_ASSERT(inputs_.at(2)->DimensionSize(0) == inputs_.at(1)->DimensionSize(0)); } diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 2f03666eb9..7aca1217b9 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -138,6 +138,30 @@ PrimitiveC *PadCreator(const schema::Primitive *primitive) { return PrimitiveC:: Registry PadRegistry(schema::PrimitiveType_Pad, PadCreator); #endif +int GetPaddingFromInput(const std::vector &inputs, std::vector *paddings) { + auto paddings_tensor = inputs.at(1); + int rank = static_cast(inputs.front()->shape().size()); + MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank); + if (paddings_tensor->data_c() == nullptr) { + return RET_INFER_ERR; + } + paddings->clear(); + if (paddings_tensor->data_type() == mindspore::kNumberTypeInt64) { + auto paddings_data = reinterpret_cast(paddings_tensor->data_c()); + for (auto i = 0; i < rank; ++i) { + paddings->emplace_back(paddings_data[i * 2]); + paddings->emplace_back(paddings_data[i * 2 + 1]); + } + } else if (paddings_tensor->data_type() == mindspore::kNumberTypeInt32) { + auto paddings_data = reinterpret_cast(paddings_tensor->data_c()); + for (auto i = 0; i < rank; ++i) { + paddings->emplace_back(paddings_data[i * 2]); + paddings->emplace_back(paddings_data[i * 2 + 1]); + } + } + return RET_OK; +} + int Pad::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); if (this->primitive_ == nullptr) { @@ -162,29 +186,12 @@ int Pad::InferShape(std::vector inputs, std::vector outputs) if (inputs.size() == 1) { paddings = GetPaddings(); } else { - // mirror pad - auto paddings_tensor = inputs.at(1); - int rank = static_cast(inputs.front()->shape().size()); - MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank); - if (paddings_tensor->MutableData() == nullptr) { - return RET_INFER_ERR; - } - paddings.clear(); - if (paddings_tensor->data_type() == mindspore::kNumberTypeInt64) { - auto paddings_data = reinterpret_cast(paddings_tensor->MutableData()); - for (auto i = 0; i < rank; ++i) { - paddings.emplace_back(paddings_data[i * 2]); - paddings.emplace_back(paddings_data[i * 2 + 1]); - } - } else if (paddings_tensor->data_type() == mindspore::kNumberTypeInt32) { - auto paddings_data = reinterpret_cast(paddings_tensor->MutableData()); - for (auto i = 0; i < rank; ++i) { - paddings.emplace_back(paddings_data[i * 2]); - paddings.emplace_back(paddings_data[i * 2 + 1]); - } - } + GetPaddingFromInput(inputs, &paddings); } + if (paddings.empty()) { + return RET_INFER_INVALID; + } auto input_shape = input->shape(); std::vector output_shape; MS_ASSERT(input->shape().size() <= 4); diff --git a/mindspore/lite/src/ops/populate/common_populate.cc b/mindspore/lite/src/ops/populate/common_populate.cc index 8255473969..3d7fe2ef97 100644 --- a/mindspore/lite/src/ops/populate/common_populate.cc +++ b/mindspore/lite/src/ops/populate/common_populate.cc @@ -31,6 +31,8 @@ OpParameter *PopulateCommonParameter(const mindspore::lite::PrimitiveC *primitiv } Registry ZerosLikeParameterRegistry(schema::PrimitiveType_ZerosLike, PopulateCommonParameter); +Registry SizeParameterRegistry(schema::PrimitiveType_Size, PopulateCommonParameter); +Registry InvertPermutationParameterRegistry(schema::PrimitiveType_InvertPermutation, PopulateCommonParameter); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/tile_populate.cc b/mindspore/lite/src/ops/populate/tile_populate.cc index 6dd170ed0e..38c6df890f 100644 --- a/mindspore/lite/src/ops/populate/tile_populate.cc +++ b/mindspore/lite/src/ops/populate/tile_populate.cc @@ -40,7 +40,7 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive) #else auto dims = param->GetDims(); auto multiples = param->GetMultiples(); - for (size_t i = 0; i < kDimension_4d; ++i) { + for (size_t i = 0; i < kQuadrupleNum; ++i) { tile_param->multiples_[i] = 1; } if (!dims.empty() && !multiples.empty()) { diff --git a/mindspore/lite/src/ops/populate/where_populate.cc b/mindspore/lite/src/ops/populate/where_populate.cc index e349fb5a15..8a90b99363 100644 --- a/mindspore/lite/src/ops/populate/where_populate.cc +++ b/mindspore/lite/src/ops/populate/where_populate.cc @@ -15,18 +15,19 @@ */ #include "src/ops/primitive_c.h" #include "src/ops/populate/populate_register.h" +#include "nnacl/where.h" namespace mindspore { namespace lite { OpParameter *PopulateWhereParameter(const mindspore::lite::PrimitiveC *primitive) { - OpParameter *where_parameter = reinterpret_cast(malloc(sizeof(OpParameter))); + WhereParameter *where_parameter = reinterpret_cast(malloc(sizeof(WhereParameter))); if (where_parameter == nullptr) { MS_LOG(ERROR) << "malloc Where parameter failed."; return nullptr; } memset(where_parameter, 0, sizeof(OpParameter)); - where_parameter->type_ = primitive->Type(); + where_parameter->op_parameter_.type_ = primitive->Type(); return reinterpret_cast(where_parameter); } Registry WhereParameterRegistry(schema::PrimitiveType_Where, PopulateWhereParameter); diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index e48124ffdc..6a5e526258 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -167,6 +167,7 @@ #include "src/ops/size.h" #include "src/ops/random_standard_normal.h" #include "src/ops/invert_permutation.h" +#include "src/ops/crop_and_resize.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -1018,6 +1019,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) InvertPermutation(primitive); case schema::PrimitiveType_RandomStandardNormal: return new (std::nothrow) RandomStandardNormal(primitive); + case schema::PrimitiveType_CropAndResize: + return new (std::nothrow) CropAndResize(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 51f96b3466..c4db83485e 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -36,8 +36,8 @@ namespace mindspore { namespace lite { constexpr uint32_t kSingleNum = 1; constexpr uint32_t kDoubleNum = 2; -constexpr uint32_t kMultiNum = 3; -constexpr uint32_t kDimension_4d = 4; +constexpr uint32_t kTripleNum = 3; +constexpr uint32_t kQuadrupleNum = 4; const std::set kSupportDataType = {kNumberTypeBool, kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeFloat16}; diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index a530e732cc..cf7748e0ba 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -153,7 +153,7 @@ int Resize::InferShape(std::vector inputs_, std::vectorElementsNum(); switch (shape_size) { - case kDimension_4d: { + case kInputRank: { if (shape_tensor->data_type() == kNumberTypeInt32) { auto data = reinterpret_cast(shape_tensor->data_c()); if (data == nullptr) { @@ -212,6 +212,12 @@ int Resize::InferShape(std::vector inputs_, std::vectordata_c() == nullptr) { + return RET_INFER_INVALID; + } + output_shape.push_back(static_cast(inputs_.at(3)->data_c())[0]); + output_shape.push_back(static_cast(inputs_.at(3)->data_c())[1]); } else { MS_LOG(ERROR) << "inputs tensor size invalid."; return RET_INFER_ERR; diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index 660157d86d..734cf22ad0 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -182,20 +182,20 @@ int Slice::InferShape(std::vector inputs, std::vector slice_size(GetSize()); std::vector slice_axes(GetAxes()); std::vector output_shape(input_shape.size()); - if (inputs.size() == kSliceMaxInputNum) { - if (slice_begin.empty() && inputs.at(1)->data_c() != nullptr) { + if (inputs.size() > kSliceInputNum && inputs.size() <= kSliceMaxInputNum) { + if (slice_begin.empty() && inputs.size() >= 2 && inputs.at(1)->data_c() != nullptr) { for (int i = 0; i < inputs.at(1)->ElementsNum(); i++) { slice_begin.emplace_back(static_cast(inputs.at(1)->data_c())[i]); } } - if (slice_size.empty() && inputs.at(2)->data_c() != nullptr) { + if (slice_size.empty() && inputs.size() >= 3 && inputs.at(2)->data_c() != nullptr) { for (int i = 0; i < inputs.at(2)->ElementsNum(); i++) { auto end = static_cast(inputs.at(2)->data_c())[i]; auto size = end < 0 ? end : (end == INT32_MAX ? -1 : end - slice_begin.at(i)); slice_size.emplace_back(size); } } - if (slice_axes.empty() && inputs.at(3)->data_c() != nullptr) { + if (slice_axes.empty() && inputs.size() >= 4 && inputs.at(3)->data_c() != nullptr) { for (int i = 0; i < inputs.at(3)->ElementsNum(); i++) { slice_axes.emplace_back(static_cast(inputs.at(3)->data_c())[i]); } @@ -220,12 +220,12 @@ int Slice::InferShape(std::vector inputs, std::vector= 0"; return RET_PARAM_INVALID; } - if (input_shape.at(i) <= begin.at(i)) { - MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << begin.at(i) - << " which should be <= " << input_shape.at(i); + if (input_shape.at(i) != 0 && input_shape.at(i) <= begin.at(i)) { + MS_LOG(ERROR) << "Invalid begin input!begin[" << i << "]=" << begin.at(i) << " which should be > " + << input_shape.at(i); return RET_PARAM_INVALID; } - if (size.at(i) > (input_shape.at(i) - begin.at(i))) { + if (input_shape.at(i) != 0 && size.at(i) > (input_shape.at(i) - begin.at(i))) { MS_LOG(ERROR) << "Invalid size input " << size.at(i) << " which should be <= " << input_shape.at(i) - begin.at(i); return RET_PARAM_INVALID; } diff --git a/mindspore/lite/src/ops/space_to_batch.cc b/mindspore/lite/src/ops/space_to_batch.cc index 4c13eea799..5a31ba90d7 100644 --- a/mindspore/lite/src/ops/space_to_batch.cc +++ b/mindspore/lite/src/ops/space_to_batch.cc @@ -100,9 +100,9 @@ int SpaceToBatch::InferShape(std::vector inputs, std::vectorshape(); - if (input_shape.size() != kDimension_4d) { + if (input_shape.size() != kQuadrupleNum) { MS_LOG(ERROR) << "Space_to_batch op only support 4D input currently. But got %d dimensionality input." - << kDimension_4d; + << kQuadrupleNum; return RET_ERROR; } diff --git a/mindspore/lite/src/ops/space_to_batch_nd.cc b/mindspore/lite/src/ops/space_to_batch_nd.cc index 273f2d3555..fda4d4d2ae 100644 --- a/mindspore/lite/src/ops/space_to_batch_nd.cc +++ b/mindspore/lite/src/ops/space_to_batch_nd.cc @@ -103,8 +103,8 @@ int SpaceToBatchND::InferShape(std::vector inputs, std::vectorshape(); - if (input_shape.size() != kDimension_4d) { - MS_LOG(ERROR) << "input shape dimension size only support " << kDimension_4d << " now!"; + if (input_shape.size() != kQuadrupleNum) { + MS_LOG(ERROR) << "input shape dimension size only support " << kQuadrupleNum << " now!"; return RET_ERROR; } auto block_shape = GetBlockShape(); diff --git a/mindspore/lite/src/ops/space_to_depth.cc b/mindspore/lite/src/ops/space_to_depth.cc index f719bf59af..e6c5eddcf3 100644 --- a/mindspore/lite/src/ops/space_to_depth.cc +++ b/mindspore/lite/src/ops/space_to_depth.cc @@ -78,8 +78,8 @@ int SpaceToDepth::InferShape(std::vector inputs, std::vectorshape(); - if (input_shape.size() != kDimension_4d) { - MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; + if (input_shape.size() != kQuadrupleNum) { + MS_LOG(ERROR) << "input shape dimension size should == " << kQuadrupleNum; return RET_ERROR; } diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc index 55294d0d27..5207e7d64e 100644 --- a/mindspore/lite/src/ops/topk.cc +++ b/mindspore/lite/src/ops/topk.cc @@ -91,7 +91,7 @@ int TopK::InferShape(std::vector inputs_, std::vector output } auto input = inputs_.front(); MS_ASSERT(input != nullptr); - if (input->shape().size() == kDimension_4d && input->format() != schema::Format::Format_NHWC) { + if (input->shape().size() == kQuadrupleNum && input->format() != schema::Format::Format_NHWC) { MS_LOG(ERROR) << "topk only support NHWC now!"; return RET_FORMAT_ERR; } diff --git a/mindspore/lite/src/ops/where.cc b/mindspore/lite/src/ops/where.cc index abeb4e61d9..95029d851d 100644 --- a/mindspore/lite/src/ops/where.cc +++ b/mindspore/lite/src/ops/where.cc @@ -66,26 +66,12 @@ int Where::InferShape(std::vector inputs_, std::vector outpu MS_ASSERT(input != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); + // Need to dynamically allocate at runtime. if (inputs_.size() == kSingleNum) { - auto input0 = inputs_.at(0); - if (input0->data_c() == nullptr) { - MS_LOG(ERROR) << "input0 is empty, tensor cannot be inferred yet"; - return RET_INFER_INVALID; - } - int dim_size = input0->shape().size(); - auto data_ptr = reinterpret_cast(input0->data_c()); - int true_num = 0; - for (int i = 0; i < input0->ElementsNum(); i++) { - if (*data_ptr) { - true_num++; - } - } - std::vector output_shape = {true_num, dim_size}; - outputs_.at(0)->set_shape(output_shape); - return RET_OK; + return RET_INFER_INVALID; } - if (inputs_.size() < kMultiNum || outputs_.size() != kSingleNum) { + if (inputs_.size() < kTripleNum || outputs_.size() != kSingleNum) { MS_LOG(ERROR) << "where input or output number invalid, Input size:" << inputs_.size() << ", output size: " << outputs_.size(); return RET_INPUT_TENSOR_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc index 4f81d7be7e..535d7f721d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc @@ -122,7 +122,7 @@ int ResizeBaseCPUKernel::CalculateNewHeightWidth() { } int ResizeBaseCPUKernel::CheckInputsOuputs() { - if (in_tensors_.size() <= lite::kDoubleNum) { + if (in_tensors_.size() <= lite::kQuadrupleNum) { for (size_t i = 0; i < in_tensors_.size(); i++) { auto input = in_tensors_.at(i); if (input == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc index 19ce8b28da..7fa29d525e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc @@ -71,8 +71,9 @@ int SplitBaseCPUKernel::ReSize() { num_unit_ = param->split_count_ * param->num_split_; thread_n_num_ = MSMIN(thread_count_, num_unit_); - MS_ASSERT(thread_n_num_); - thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + if (thread_n_num_ != 0) { + thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + } return RET_OK; } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc index c7ef11fa5e..4386125476 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc @@ -114,6 +114,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NotEqual, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Less, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LessEqual, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_LessEqual, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Greater, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_GreaterEqual, LiteKernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop_and_resize_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/crop_and_resize_fp32.cc index 3b49691b53..1331f9691b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/crop_and_resize_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop_and_resize_fp32.cc @@ -43,6 +43,7 @@ int CropAndResizeCPUKernel::ReSize() { } int CropAndResizeCPUKernel::MallocTmpBuffer() { + batch_ = out_tensors_[0]->Batch(); // Malloc buffer to save coordinate. // For mode CROP_AND_RESIZE, different output batches require different cache coordinates. int c = in_tensors_.at(0)->Channel(); @@ -153,7 +154,6 @@ int CropAndResizeCPUKernel::Run() { ret = PrepareResizeBilinear(input_shape.data(), out_tensors_.at(0)->shape().data(), CalculateAlignCorners, y_bottoms_, y_tops_, x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_); } else { - batch_ = out_tensors_[0]->Batch(); auto boxes = reinterpret_cast(in_tensors_.at(1)->data_c()); auto box_idx = reinterpret_cast(in_tensors_.at(2)->data_c()); ret = PrepareCropAndResizeBilinear(input_shape.data(), boxes, box_idx, out_tensors_.at(0)->shape().data(), diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.cc index 1526b10577..1e62da62e6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.cc @@ -38,7 +38,9 @@ int FillCPUKernel::Init() { int FillCPUKernel::ReSize() { data_size_ = out_tensors_.front()->ElementsNum(); thread_sz_count_ = MSMIN(thread_count_, data_size_); - thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + if (thread_sz_count_ != 0) { + thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_); + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.h index 09228635f1..92e2f7bffc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fill_fp32.h @@ -38,8 +38,8 @@ class FillCPUKernel : public LiteKernel { int DoFill(int task_id); private: - int thread_sz_count_; - int thread_sz_stride_; + int thread_sz_count_ = 0; + int thread_sz_stride_ = 0; int data_size_; float src_data_; float *out_ptr_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc index e5fc4d2dca..d18c530fc6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.cc @@ -67,7 +67,9 @@ int GatherNdCPUKernel::ReSize() { } (void)memset(in_offset_, 0, count_ * sizeof(int)); thread_sz_count_ = MSMIN(thread_count_, count_); - thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); + if (thread_sz_count_ != 0) { + thread_sz_stride_ = UP_DIV(count_, thread_sz_count_); + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h index b24978abcf..51cba5945b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gatherNd_fp32.h @@ -42,8 +42,8 @@ class GatherNdCPUKernel : public LiteKernel { private: void InitOffset(); - int thread_sz_count_; - int thread_sz_stride_; + int thread_sz_count_ = 0; + int thread_sz_stride_ = 0; int count_; int area_; int *in_offset_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression_fp32.cc index 7a0e671666..b80cfc7960 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/non_max_suppression_fp32.cc @@ -97,15 +97,23 @@ int NonMaxSuppressionCPUKernel::GetParams() { int NonMaxSuppressionCPUKernel::PreProcess() { return GetParams(); } +void ExpandDims(std::vector *shape, size_t size) { + for (size_t i = 0; i < size; i++) { + shape->insert(shape->begin(), 1); + } +} + int NonMaxSuppressionCPUKernel::Run() { auto box_tensor = in_tensors_.at(kBoxTensorIndex); if (box_tensor == nullptr) { return RET_ERROR; } + bool simple_out = false; auto box_dims = box_tensor->shape(); // batch, box_num, 4 constexpr size_t kBoxTensorDims = 3; if (box_dims.size() != kBoxTensorDims) { - return RET_ERROR; + ExpandDims(&box_dims, kBoxTensorDims - box_dims.size()); + simple_out = true; } constexpr size_t kBoxCoordIndex = 2; if (box_dims[kBoxCoordIndex] != kBoxPointNum) { @@ -119,7 +127,7 @@ int NonMaxSuppressionCPUKernel::Run() { auto score_dims = score_tensor->shape(); // batch, class, box_num constexpr size_t kScoreTensorDims = 3; if (score_dims.size() != kScoreTensorDims) { - return RET_ERROR; + ExpandDims(&score_dims, kScoreTensorDims - score_dims.size()); } constexpr size_t kBatchIndex = 0; if (score_dims.at(kBatchIndex) != box_dims.at(kBatchIndex)) { @@ -206,11 +214,29 @@ int NonMaxSuppressionCPUKernel::Run() { } auto output = out_tensors_.at(0); int selected_num = static_cast(selected_index.size()); - const int output_last_dim = 3; - output->set_shape({selected_num, output_last_dim}); - MS_ASSERT(output_last_dim * sizeof(int32_t) == sizeof(NMSIndex)); - int32_t *out_data = reinterpret_cast(output->MutableData()); - memcpy(out_data, selected_index.data(), selected_index.size() * sizeof(NMSIndex)); + if (!simple_out) { + const int output_last_dim = 3; + output->set_shape({selected_num, output_last_dim}); + MS_ASSERT(output_last_dim * sizeof(int32_t) == sizeof(NMSIndex)); + auto *out_data = reinterpret_cast(output->MutableData()); + if (out_data == nullptr) { + MS_LOG(ERROR) << "out_data is nullptr."; + return RET_ERROR; + } + memcpy(out_data, selected_index.data(), selected_index.size() * sizeof(NMSIndex)); + } else { + output->set_shape({selected_num}); + std::vector result; + for (size_t i = 0; i < selected_index.size(); i++) { + result.push_back(selected_index[i].box_index_); + } + auto *out_data = reinterpret_cast(output->MutableData()); + if (out_data == nullptr) { + MS_LOG(ERROR) << "out_data is nullptr."; + return RET_ERROR; + } + memcpy(out_data, result.data(), result.size() * sizeof(int)); + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pad_fp32.cc index 849fe85f5e..159c93a54a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/pad_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad_fp32.cc @@ -29,7 +29,8 @@ using mindspore::schema::PrimitiveType_Pad; namespace mindspore::kernel { namespace { constexpr size_t kMirrorPadInputSize = 2; -} +constexpr size_t kPadMaxInputSize = 2; +} // namespace int PadCPUKernel::Init() { if (!InferShapeDone()) { return RET_OK; @@ -381,6 +382,9 @@ int PadCPUKernel::HandleMirrorPad() { int PadCPUKernel::Run() { int error_code; if (pad_param_->pad_mode_ == static_cast(schema::PaddingMode_CONSTANT)) { + if (in_tensors_.size() == kPadMaxInputSize) { + CopyPaddingFromInput(); + } auto output = out_tensors_.at(0); int output_size = output->ElementsNum(); auto output_data = reinterpret_cast(output->data_c()); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc index 4899b24437..16b133989b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reverse_fp32.cc @@ -136,4 +136,5 @@ int ReverseCPUKernel::Run() { } REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reverse, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Reverse, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.cc index ac67e86d47..cdb9f498ec 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.cc @@ -15,11 +15,13 @@ */ #include "src/runtime/kernel/arm/fp32/where_fp32.h" #include +#include #include "schema/model_generated.h" #include "nnacl/where.h" #include "src/kernel_registry.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" +#include "nnacl/common_func.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -28,18 +30,28 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Where; namespace mindspore::kernel { +constexpr uint32_t kSingleNum = 1; +constexpr uint32_t kTripleNum = 3; int WhereCPUKernel::Init() { where_param_->op_parameter_.thread_num_ = thread_count_; return RET_OK; } +int WhereCPUKernel::PreProcess() { + if (in_tensors_.size() == kTripleNum) { + return LiteKernel::PreProcess(); + } else { + return RET_OK; + } +} + int WhereCPUKernel::DoExcute(int task_id) { - MS_ASSERT(input_data); - MS_ASSERT(input_data1); - MS_ASSERT(input_data2); - MS_ASSERT(output_data); + MS_ASSERT(condition_); + MS_ASSERT(x_); + MS_ASSERT(y_); + MS_ASSERT(output_data_); MS_ASSERT(where_param_); - Where(input_data, input_data1, input_data2, output_data, where_param_, task_id); + WhereWithTripleInputs(condition_, x_, y_, output_data_, where_param_, task_id); return RET_OK; } @@ -52,29 +64,68 @@ int WhereRun(void *cdata, int task_id) { } return RET_OK; } -int WhereCPUKernel::Run() { + +int WhereCPUKernel::RunWithSingleInput() { auto input = in_tensors_.at(0); MS_ASSERT(input); - auto input1 = in_tensors_.at(1); - MS_ASSERT(input1); - auto input2 = in_tensors_.at(2); - MS_ASSERT(input2); - int num = input->ElementsNum(); - int num1_ = input1->ElementsNum(); - int num2_ = input2->ElementsNum(); + condition_ = reinterpret_cast(input->data_c()); + where_param_->condition_num_ = input->ElementsNum(); + where_param_->rank_ = input->shape().size(); + int strides[8]; + ComputeStrides(in_tensors_.at(0)->shape().data(), strides, where_param_->rank_); - input_data = reinterpret_cast(input->MutableData()); - input_data1 = reinterpret_cast(input1->MutableData()); - input_data2 = reinterpret_cast(input2->MutableData()); - output_data = reinterpret_cast(out_tensors_.at(0)->MutableData()); - int num_max = num > num1_ ? num : (num1_ > num2_ ? num1_ : num2_); - where_param_->num_ = num; - where_param_->num1_ = num1_; - where_param_->num2_ = num2_; - where_param_->number_ = num_max; + auto data = context_->allocator->Malloc(where_param_->condition_num_ * where_param_->rank_ * sizeof(int32_t)); + int *result = reinterpret_cast(data); - if (((num != 1) && (num != num_max)) || ((num1_ != 1) && (num1_ != num_max)) || - ((num2_ != 1) && (num2_ != num_max))) { + int result_index = 0; + int true_num = 0; + for (int index = 0; index < where_param_->condition_num_; index++) { + if (condition_[index]) { + true_num++; + int dim = index; + for (int j = 0; j < where_param_->rank_; j++) { + result[result_index++] = dim / strides[j]; + dim %= strides[j]; + } + } + } + out_tensors_.at(0)->set_data_type(kNumberTypeInt32); + std::vector output_shape = {true_num, where_param_->rank_}; + out_tensors_.at(0)->set_shape(output_shape); + out_tensors_.at(0)->FreeData(); + auto out_data = out_tensors_.at(0)->MutableData(); + if (out_data == nullptr) { + MS_LOG(ERROR) << "malloc out tensor failed."; + return RET_ERROR; + } + memcpy(out_data, result, true_num * where_param_->rank_ * sizeof(int32_t)); + context_->allocator->Free(data); + return RET_OK; +} + +int WhereCPUKernel::RunWithTripleInputs() { + auto condition = in_tensors_.at(0); + MS_ASSERT(condition); + auto x = in_tensors_.at(1); + MS_ASSERT(x); + auto y = in_tensors_.at(2); + MS_ASSERT(y); + int condition_nums = condition->ElementsNum(); + int x_num = x->ElementsNum(); + int y_num = y->ElementsNum(); + + condition_ = reinterpret_cast(condition->data_c()); + x_ = reinterpret_cast(x->data_c()); + y_ = reinterpret_cast(y->data_c()); + output_data_ = reinterpret_cast(out_tensors_.at(0)->data_c()); + int num_max = condition_nums > x_num ? condition_nums : (x_num > y_num ? x_num : y_num); + where_param_->condition_num_ = condition_nums; + where_param_->x_num_ = x_num; + where_param_->y_num_ = y_num; + where_param_->max_num_ = num_max; + + if (((condition_nums != 1) && (condition_nums != num_max)) || ((x_num != 1) && (x_num != num_max)) || + ((y_num != 1) && (y_num != num_max))) { MS_LOG(ERROR) << "The length of three inputs are not equal to 1 or length of output, which is unacceptable"; return RET_ERROR; } @@ -90,6 +141,22 @@ int WhereCPUKernel::Run() { return RET_OK; } +int WhereCPUKernel::Run() { + int ret = RET_ERROR; + if (in_tensors_.size() == kSingleNum) { + ret = RunWithSingleInput(); + } else if (in_tensors_.size() == kTripleNum) { + ret = RunWithTripleInputs(); + } else { + MS_LOG(ERROR) << "in tensor size is invalid. size is " << in_tensors_.size(); + } + if (ret != RET_OK) { + MS_LOG(ERROR) << "Where op run failed."; + } + return ret; +} + REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Where, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Where, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Where, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.h index 830a65fe5b..08a256fc93 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/where_fp32.h @@ -37,8 +37,11 @@ class WhereCPUKernel : public LiteKernel { ~WhereCPUKernel() = default; int Init() override; + int PreProcess() override; int ReSize() override { return 0; } int Run() override; + int RunWithSingleInput(); + int RunWithTripleInputs(); int DoExcute(int task_id); protected: @@ -47,10 +50,10 @@ class WhereCPUKernel : public LiteKernel { WhereParameter *where_param_; private: - bool *input_data; - float *input_data1; - float *input_data2; - float *output_data; + bool *condition_; + float *x_; + float *y_; + float *output_data_; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_WHERE_H_ diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index a7c8fdd07f..2152c07d22 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -92,7 +92,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // postconvert pass { - // init old node indecies + // init old node indices auto old_nodes = GetGraphNodes(); Optimizer fusionOptimizer; if (!ctx.trainModel) { @@ -113,9 +113,9 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - if (ctx.fmk != converter::FmkType_TF) { + { // format transform - // init old node indecies + // init old node indices auto old_nodes = GetGraphNodes(); Optimizer formatTransOptimizer; @@ -126,19 +126,20 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } formatTransPass->SetQuantType(ctx.quantType); formatTransPass->SetFmk(ctx.fmk); - formatTransOptimizer.AddPass(formatTransPass); - formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); + if (ctx.fmk != converter::FmkType_TF) { + formatTransOptimizer.AddPass(formatTransPass); + formatTransOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); + } status = formatTransOptimizer.Run(graphDefT); if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; return status; } } - { - // init old node indecies + // init old node indices auto old_nodes = GetGraphNodes(); Optimizer formatTransOptimizer; formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); @@ -156,7 +157,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } { - // init old node indecies + // init old node indices auto old_nodes = GetGraphNodes(); Optimizer formatTransOptimizer; if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) { @@ -172,7 +173,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } { - // init old node indecies + // init old node indices auto old_nodes = GetGraphNodes(); Optimizer fusionOptimizer; fusionOptimizer.AddPass(new (std::nothrow) MulAddFusionPass()); @@ -187,7 +188,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // do quantization if (ctx.fmk != converter::FmkType_TF) { - // init old node indecies + // init old node indices auto old_nodes = GetGraphNodes(); Optimizer tensorQuantOptimizer; tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); @@ -203,7 +204,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // insert quantNode and deQuantNode if (ctx.fmk != converter::FmkType_TF) { - // init old node indecies + // init old node indices auto old_nodes = GetGraphNodes(); Optimizer quantNodeOptimizer; auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); @@ -236,7 +237,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // switch pass { - // init old node indecies + // init old node indices auto old_nodes = GetGraphNodes(); Optimizer switchOptimizer; switchOptimizer.AddPass(new (std::nothrow) SwitchPass()); @@ -262,7 +263,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // tensor name { - // init old node indecies + // init old node indices auto old_nodes = GetGraphNodes(); Optimizer nameOptimizer; nameOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc index f949d8128a..0513490203 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc @@ -30,9 +30,6 @@ namespace lite { #define kOutputNum 1 STATUS FormatTransPass::Run(schema::MetaGraphT *graph) { - if (fmkType == converter::FmkType_TF) { - return RET_OK; - } MS_ASSERT(graph != nullptr); auto status = DoModelInputFormatTrans(graph); if (status != RET_OK) { @@ -124,6 +121,14 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { } beforeNodeType = kNCHW2NHWC; afterNodeType = kNHWC2NCHW; + } else if (fmkType == converter::FmkType_TF) { + auto &node = *iter; + if (IsContain(GetNhwcOpList(), GetCNodeTType(**iter)) && GetFormat(node) == schema::Format_NCHW) { + beforeNodeType = kNCHW2NHWC; + afterNodeType = kNHWC2NCHW; + } else { + continue; + } } else { MS_LOG(ERROR) << "Unsupported fmk: " << fmkType; return RET_ERROR; @@ -244,5 +249,23 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } void FormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; } + +int FormatTransPass::GetFormat(const std::unique_ptr &node) { + switch (node->primitive->value.type) { + case schema::PrimitiveType_Conv2D: + return node->primitive->value.AsConv2D()->format; + case schema::PrimitiveType_DeConv2D: + return node->primitive->value.AsDeConv2D()->format; + case schema::PrimitiveType_DeDepthwiseConv2D: + return node->primitive->value.AsDeDepthwiseConv2D()->format; + case schema::PrimitiveType_DepthwiseConv2D: + return node->primitive->value.AsDepthwiseConv2D()->format; + case schema::PrimitiveType_Pooling: + return node->primitive->value.AsPooling()->format; + default: + return schema::Format_NHWC; + } +} + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h index 14f4e1ab59..1f6f6d9003 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H #define MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H +#include #include "tools/converter/optimizer.h" #include "tools/common/graph_util.h" #include "tools/converter/converter_flags.h" @@ -46,6 +47,8 @@ class FormatTransPass : public GraphPass { STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph); + int GetFormat(const std::unique_ptr &node); + protected: size_t id = 0; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.cc index 7f758eb414..27ac7c988f 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_cast_parser.cc @@ -41,18 +41,11 @@ STATUS TFCastParser::Parse(const tensorflow::NodeDef &tf_op, MS_LOG(ERROR) << "new attr failed"; return RET_NULL_PTR; } - - auto src_type = TensorFlowUtils::ParseAttrDataType(tf_op, "SrcT"); - if (src_type == kTypeUnknown) { - MS_LOG(ERROR) << "Get attr SrcT failed"; - return RET_ERROR; - } auto dst_type = TensorFlowUtils::ParseAttrDataType(tf_op, "DstT"); if (dst_type == kTypeUnknown) { MS_LOG(ERROR) << "Get attr DstT failed"; return RET_ERROR; } - attr->srcT = src_type; attr->dstT = dst_type; primitive->value.type = schema::PrimitiveType_Cast; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.cc index 28cdc7fc64..f1859fffca 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_crop_and_resize_parser.cc @@ -71,7 +71,7 @@ STATUS TFCropAndResizeParser::Parse(const tensorflow::NodeDef &tf_op, MS_LOG(ERROR) << "Do not support method: " << attr_value.s(); } - primitive->value.type = schema::PrimitiveType_Resize; + primitive->value.type = schema::PrimitiveType_CropAndResize; primitive->value.value = attr.release(); *primitiveC = PrimitiveC::Create(primitive.release()); if (*primitiveC == nullptr) { diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 747bf3ae27..3a4a7b44a8 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -195,6 +195,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, co MS_LOG(ERROR) << "param_value is nullptr"; return RET_ERROR; } + param_value->set_tensor_type(type); if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) { auto tensor_data = new (std::nothrow) float[shape_size]; if (tensor_proto.float_val_size() == 1) { @@ -266,24 +267,35 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef &node_def, co MS_LOG(ERROR) << "new data failed"; return RET_ERROR; } - const auto origin_data = reinterpret_cast(tensor_proto.tensor_content().data()); - for (int i = 0; i < shape_size; ++i) { - if (origin_data[i] > static_cast(INT32_MAX) || origin_data[i] < static_cast(INT32_MIN)) { - MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32"; - tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN; - } else { - tensor_data[i] = static_cast(origin_data[i]); + if (tensor_shape.dim_size() == 0) { // scalar + const auto &origin_data = tensor_proto.int64_val(); + for (int i = 0; i < tensor_proto.int64_val_size(); ++i) { + if (origin_data[i] > static_cast(INT32_MAX) || origin_data[i] < static_cast(INT32_MIN)) { + MS_LOG(ERROR) << "int64 data " << origin_data[i] << "too big to fit into int32"; + return RET_ERROR; + } else { + tensor_data[i] = static_cast(origin_data[i]); + } + } + } else { + const auto origin_data = reinterpret_cast(tensor_proto.tensor_content().data()); + for (int i = 0; i < shape_size; ++i) { + if (origin_data[i] > static_cast(INT32_MAX) || origin_data[i] < static_cast(INT32_MIN)) { + MS_LOG(WARNING) << "int64 data " << origin_data[i] << "too big to fit into int32"; + tensor_data[i] = origin_data[i] > 0 ? INT32_MAX : INT32_MIN; + } else { + tensor_data[i] = static_cast(origin_data[i]); + } } } param_value->SetTensorData(tensor_data, shape_size * sizeof(int32_t)); } else { - MS_LOG(ERROR) << "Unsupport dataType: " << type; + MS_LOG(ERROR) << "Unsupported dataType: " << type; return RET_ERROR; } std::vector param_shape(shape_vector->begin(), shape_vector->end()); param_value->set_tensor_shape(param_shape); - param_value->set_tensor_type(type); if (TensorFlowUtils::FindAttrValue(node_def, "data_format", const_cast(&attr_value))) { auto format = mindspore::lite::TensorFlowUtils::ParseNodeFormat(node_def); if (format == schema::Format_NUM_OF_FORMAT) { @@ -307,7 +319,6 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa if (TensorFlowUtils::FindAttrValue(node, "dtype", &attr_value)) { type = TensorFlowUtils::GetTFDataType(attr_value.type()); } - auto type_ptr = TypeIdToType(type); std::vector shape; if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) { @@ -327,7 +338,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa } else { graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names } - + if (type == kNumberTypeInt64) { + type = kNumberTypeInt32; + } + auto type_ptr = TypeIdToType(type); auto abstract_tensor = std::make_shared(type_ptr, shape_vector); if (abstract_tensor == nullptr) { MS_LOG(ERROR) << "abstract_tensor is nullptr"; @@ -474,16 +488,16 @@ STATUS TFModelParser::ConvertSubgraph() { std::vector sub_graph_inputs; for (int j = 0; j < input_arg_size; j++) { auto &input_arg = tf_sub_signature.input_arg(j); - auto paramter = sub_func_graph->add_parameter(); - paramter->set_name(input_arg.name()); - anf_sub_node_map[input_arg.name()] = paramter; + auto parameter = sub_func_graph->add_parameter(); + parameter->set_name(input_arg.name()); + anf_sub_node_map[input_arg.name()] = parameter; auto root_inputs = cnode->inputs(); if (op_type == schema::PrimitiveType_While) { - paramter->set_abstract(root_inputs[j + 1]->abstract()); + parameter->set_abstract(root_inputs[j + 1]->abstract()); } else { - paramter->set_abstract(root_inputs[j + 2]->abstract()); + parameter->set_abstract(root_inputs[j + 2]->abstract()); } - sub_graph_inputs.emplace_back(paramter); + sub_graph_inputs.emplace_back(parameter); } std::map tf_sub_node_map; for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) { diff --git a/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.cc index fc893ebda3..328ab146ed 100644 --- a/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/unused_cast_node_remove_pass.cc @@ -54,7 +54,7 @@ bool RemoveUnusedCastOpPass::Run(const FuncGraphPtr &func_graph) { MS_ASSERT(input_type != nullptr); auto input_type_value = input_type->type_id(); - if (cast_cnode->inputs().size() != lite::kMultiNum || !utils::isa(cast_cnode->input(2))) { + if (cast_cnode->inputs().size() != lite::kTripleNum || !utils::isa(cast_cnode->input(2))) { MS_LOG(ERROR) << "Second input of cast should be a ValueNode"; return RET_ERROR; }