Browse Source

!13322 [MS][LITE] fix bug of merge op infershape

From: @mengyuanli
Reviewed-by: @zhang_xue_tong
Signed-off-by: @zhang_xue_tong
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
8ed460f93f
8 changed files with 29 additions and 79 deletions
  1. +1
    -0
      mindspore/lite/nnacl/infer/common_infer.h
  2. +15
    -42
      mindspore/lite/nnacl/infer/merge_infer.c
  3. +0
    -2
      mindspore/lite/nnacl/infer/merge_infer.h
  4. +1
    -0
      mindspore/lite/nnacl/tensor_c.h
  5. +9
    -32
      mindspore/lite/src/common/tensor_util.cc
  6. +1
    -1
      mindspore/lite/src/common/tensor_util.h
  7. +1
    -1
      mindspore/lite/src/kernel_registry.cc
  8. +1
    -1
      mindspore/lite/src/runtime/infer_manager.cc

+ 1
- 0
mindspore/lite/nnacl/infer/common_infer.h View File

@@ -134,6 +134,7 @@ typedef struct vvector {
 } vvector;
 
 typedef struct TensorListC {
+  bool is_ready_;
   int data_type_;
   int format_;




+ 15
- 42
mindspore/lite/nnacl/infer/merge_infer.c View File

@@ -18,36 +18,19 @@
 #include <string.h>
 #include "nnacl/infer/infer_register.h"
 
-int MergeAbleToInfer(const TensorC *const *inputs, size_t inputs_size) {
+bool MergeAbleToInfer(const TensorC *const *inputs, size_t inputs_size) {
   for (size_t i = 0; i < inputs_size; i++) {
-    if (inputs[i]->shape_size_ == 0) {
-      return HasZeroShape;
-    }
-    if (inputs[i]->data_ == NULL) {
-      return NotAble;
+    if (!inputs[i]->is_ready_) {
+      return false;
     }
   }
-  return Able;
+  return true;
 }
 
-int MergeInfer(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size) {
+int MergeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size) {
   for (size_t i = 0; i < inputs_size; i++) {
-    SetDataTypeFormat(outputs[i], inputs[i]);
-    if (((TensorListC *)inputs[i])->data_type_ == kObjectTypeTensorType) {
-      TensorListC *input_tensorlist = (TensorListC *)inputs[i];
-      TensorListC *output_tensorlist = (TensorListC *)outputs[i];
-      ShapeSet(output_tensorlist->element_shape_, &output_tensorlist->element_shape_size_,
-               input_tensorlist->element_shape_, input_tensorlist->element_shape_size_);
-      output_tensorlist->max_elements_num_ = input_tensorlist->max_elements_num_;
-      output_tensorlist->tensors_data_type_ = input_tensorlist->tensors_data_type_;
-
-      output_tensorlist->element_num_ = input_tensorlist->element_num_;
-      for (size_t j = 0; j < output_tensorlist->element_num_; j++) {
-        memcpy(&output_tensorlist->tensors_[j], &input_tensorlist->tensors_[j], sizeof(TensorC));
-      }
-    } else {
-      SetShapeTensor(outputs[i], inputs[i]);
-    }
+    outputs[i] = inputs[i];
+    inputs[i] = NULL;
   }
   return NNACL_OK;
 }
@@ -55,9 +38,10 @@ int MergeInfer(const TensorC *const *inputs, size_t inputs_size, TensorC **outpu
 int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter) {
 #ifdef Debug
-  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
-  if (check_ret != NNACL_OK) {
-    return check_ret;
+  for (size_t i = 0; i < inputs_size; i++) {
+    if (inputs[i] == NULL) {
+      return NNACL_NULL_PTR;
+    }
   }
   if (inputs_size != 2 * outputs_size) {
     return NNACL_ERR;
@@ -67,12 +51,6 @@ int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
   if (!parameter->infer_flag_) {
     return NNACL_INFER_INVALID;
   }
-  for (size_t i = 0; i < outputs_size; ++i) {
-    outputs[i]->data_type_ = inputs[i]->data_type_;
-  }
-  if (!parameter->infer_flag_) {
-    return NNACL_INFER_INVALID;
-  }
 
   const TensorC *const *left_part_inputs = inputs;
   size_t left_part_inputs_size = inputs_size / 2;
@@ -80,17 +58,12 @@ int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
   const TensorC *const *right_part_inputs = inputs + left_part_inputs_size;
   size_t right_part_inputs_size = inputs_size / 2;
 
-  if (MergeAbleToInfer(left_part_inputs, left_part_inputs_size) == Able) {
-    return MergeInfer(left_part_inputs, left_part_inputs_size, outputs, outputs_size);
-  }
-
-  if (MergeAbleToInfer(right_part_inputs, right_part_inputs_size) == Able) {
-    return MergeInfer(right_part_inputs, right_part_inputs_size, outputs, outputs_size);
+  if (MergeAbleToInfer(left_part_inputs, left_part_inputs_size)) {
+    return MergeInfer((TensorC **)left_part_inputs, left_part_inputs_size, outputs, outputs_size);
   }
 
-  if (MergeAbleToInfer(left_part_inputs, left_part_inputs_size) == HasZeroShape &&
-      MergeAbleToInfer(right_part_inputs, right_part_inputs_size) == HasZeroShape) {
-    return MergeInfer(left_part_inputs, left_part_inputs_size, outputs, outputs_size);
+  if (MergeAbleToInfer(right_part_inputs, right_part_inputs_size)) {
+    return MergeInfer((TensorC **)right_part_inputs, right_part_inputs_size, outputs, outputs_size);
   }
 
   return NNACL_INFER_INVALID;


+ 0
- 2
mindspore/lite/nnacl/infer/merge_infer.h View File

@@ -23,8 +23,6 @@
 extern "C" {
 #endif
 
-enum InferStatus { Able, NotAble, HasZeroShape };
-
 int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter);




+ 1
- 0
mindspore/lite/nnacl/tensor_c.h View File

@@ -18,6 +18,7 @@
 #include "nnacl/op_base.h"
 
 typedef struct TensorC {
+  bool is_ready_;
   int data_type_;
   int format_;
   void *data_;


+ 9
- 32
mindspore/lite/src/common/tensor_util.cc View File

@@ -61,11 +61,13 @@ int OutputTensor2TensorC(const std::vector<lite::Tensor *> &tensors, std::vector
   return RET_OK;
 }
 
-void TensorC2LiteTensor(const std::vector<TensorC *> &tensors_in, std::vector<lite::Tensor *> *tensors_out) {
+void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<lite::Tensor *> *tensors_out) {
   for (size_t i = 0; i < tensors_in.size(); ++i) {
-    tensors_out->at(i)->set_format(static_cast<schema::Format>(tensors_in[i]->format_));
-    tensors_out->at(i)->set_data_type(static_cast<TypeId>(tensors_in[i]->data_type_));
-    tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_});
+    if (tensors_in[i] != nullptr) {
+      tensors_out->at(i)->set_format(static_cast<schema::Format>(tensors_in[i]->format_));
+      tensors_out->at(i)->set_data_type(static_cast<TypeId>(tensors_in[i]->data_type_));
+      tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_});
+    }
   }
 }


@@ -108,6 +110,7 @@ TensorC *NewTensorC() {
 }
 
 void Tensor2TensorC(Tensor *src, TensorC *dst) {
+  dst->is_ready_ = src->IsReady();
   dst->format_ = src->format();
   dst->data_ = src->data_c();
   dst->data_type_ = src->data_type();
@@ -124,6 +127,7 @@ void TensorC2Tensor(TensorC *src, Tensor *dst) {
 }
 
 int TensorList2TensorListC(TensorList *src, TensorListC *dst) {
+  dst->is_ready_ = src->IsReady();
   dst->data_type_ = static_cast<TypeIdC>(src->data_type());
   dst->format_ = src->format();
   dst->element_num_ = src->shape().empty() ? 0 : src->tensors().size();
@@ -165,34 +169,7 @@ int GenerateMergeOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vect
                             std::vector<TensorC *> *out_tensor_c) {
   int ret = RET_OK;
   for (size_t i = 0; i < outputs->size(); i++) {
-    if (inputs.at(i)->data_type() == kObjectTypeTensorType) {
-      auto *output_tensorlist = reinterpret_cast<TensorListC *>(malloc(sizeof(TensorListC)));
-      if (output_tensorlist == nullptr) {
-        return RET_ERROR;
-      }
-      memset(output_tensorlist, 0, sizeof(TensorListC));
-      output_tensorlist->element_num_ = inputs[i]->shape().empty() ? 0 : inputs[i]->shape().at(0);
-      if (output_tensorlist->element_num_ != 0) {
-        output_tensorlist->tensors_ =
-          reinterpret_cast<TensorC *>(malloc(output_tensorlist->element_num_ * sizeof(TensorC)));
-        if (output_tensorlist->tensors_ == nullptr) {
-          free(output_tensorlist);
-          output_tensorlist = nullptr;
-          return RET_ERROR;
-        }
-        memset(output_tensorlist->tensors_, 0, output_tensorlist->element_num_ * sizeof(TensorC));
-      }
-
-      out_tensor_c->push_back(reinterpret_cast<TensorC *const>(output_tensorlist));
-    } else {
-      auto *output_tensor = NewTensorC();
-      if (output_tensor == nullptr) {
-        MS_LOG(ERROR) << "malloc tensor_c failed";
-        ret = RET_ERROR;
-        break;
-      }
-      out_tensor_c->push_back(reinterpret_cast<TensorC *const>(output_tensor));
-    }
+    out_tensor_c->push_back(nullptr);
   }
   return ret;
 }


+ 1
- 1
mindspore/lite/src/common/tensor_util.h View File

@@ -26,7 +26,7 @@ namespace mindspore {
 namespace lite {
 int InputTensor2TensorC(const std::vector<lite::Tensor *> &tensors_in, std::vector<TensorC *> *tensors_out);
 int OutputTensor2TensorC(const std::vector<lite::Tensor *> &tensors_in, std::vector<TensorC *> *tensors_out);
-void TensorC2LiteTensor(const std::vector<TensorC *> &tensors_in, std::vector<lite::Tensor *> *tensors_out);
+void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<lite::Tensor *> *tensors_out);
 void FreeAllTensorC(std::vector<TensorC *> *tensors_in);
 void FreeTensorListC(TensorListC *tensorListC);
 TensorC *NewTensorC();


+ 1
- 1
mindspore/lite/src/kernel_registry.cc View File

@@ -64,7 +64,7 @@ int KernelRegistry::Init() {
 
 kernel::KernelCreator KernelRegistry::GetCreator(const KernelKey &desc) {
   int index = GetCreatorFuncIndex(desc);
-  if (index >= array_size_) {
+  if (index >= array_size_ || index < 0) {
     MS_LOG(ERROR) << "invalid kernel key, arch " << desc.arch << ", data_type" << desc.data_type << ",op type "
                   << desc.type;
     return nullptr;


+ 1
- 1
mindspore/lite/src/runtime/infer_manager.cc View File

@@ -66,7 +66,7 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, std::vector<lite
       }
     }
   } else {
-    TensorC2LiteTensor(out_tensors, outputs);
+    SetOutputTensorAttr(out_tensors, outputs);
   }
 
   FreeAllTensorC(&in_tensors);


Loading…
Cancel
Save