Browse Source

fix switch op infershape bug && adding dataType check in arithmetic op

tags/v1.2.0-rc1
fuzhiye 4 years ago
parent
commit
ce4fe0bcf9
5 changed files with 66 additions and 6 deletions
  1. +31
    -4
      mindspore/lite/src/ops/switch.cc
  2. +18
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
  3. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h
  4. +15
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
  5. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h

+ 31
- 4
mindspore/lite/src/ops/switch.cc View File

@@ -73,6 +73,37 @@ Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);


int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size()); MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size());
for (size_t i = 0; i < outputs_.size() / 2; i++) {
auto *input = inputs_[i + 1];
auto *output_true = outputs_[i];
auto *output_false = outputs_[i + outputs_.size() / 2];
if (input == nullptr) {
MS_LOG(ERROR) << "input tensor is nullptr";
return RET_ERROR;
}
if (output_true == nullptr || output_false == nullptr) {
MS_LOG(ERROR) << "output tensor is nullptr";
return RET_ERROR;
}
output_true->set_data_type(input->data_type());
output_false->set_data_type(input->data_type());
output_true->set_format(input->format());
output_false->set_format(input->format());
auto data_type = input->data_type();
if (data_type != kObjectTypeTensorType) {
continue;
} else {
auto input_tensorlist = reinterpret_cast<TensorList *>(input);
auto output_true_tensorlist = reinterpret_cast<TensorList *>(output_true);
auto output_false_tensorlist = reinterpret_cast<TensorList *>(output_false);
output_true_tensorlist->set_element_shape(input_tensorlist->element_shape());
output_false_tensorlist->set_element_shape(input_tensorlist->element_shape());
output_true_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
output_false_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num());
output_true_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
output_false_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type());
}
}
if (!infer_flag()) { if (!infer_flag()) {
return RET_INFER_INVALID; return RET_INFER_INVALID;
} }
@@ -88,12 +119,8 @@ int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
MS_LOG(ERROR) << "output tensor is nullptr"; MS_LOG(ERROR) << "output tensor is nullptr";
return RET_ERROR; return RET_ERROR;
} }
output_true->set_data_type(input->data_type());
output_false->set_data_type(input->data_type());
output_true->set_shape(input->shape()); output_true->set_shape(input->shape());
output_false->set_shape(input->shape()); output_false->set_shape(input->shape());
output_true->set_format(input->format());
output_false->set_format(input->format());
auto data_type = input->data_type(); auto data_type = input->data_type();
if (data_type != kObjectTypeTensorType) { if (data_type != kObjectTypeTensorType) {
continue; continue;


+ 18
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc View File

@@ -118,7 +118,24 @@ void ArithmeticFP16CPUKernel::InitParam() {
return; return;
} }


int ArithmeticFP16CPUKernel::CheckDataType() {
auto in0_dataType = in_tensors_.at(0)->data_type();
auto in1_dataType = in_tensors_.at(1)->data_type();
if ((in0_dataType != kNumberTypeFloat16 && in0_dataType != kNumberTypeFloat32) ||
(in1_dataType != kNumberTypeFloat16 && in1_dataType != kNumberTypeFloat32)) {
MS_LOG(ERROR)
<< "The dataTypes of input tensor0 and input tensor1 should be any of float16 and float32, otherwise got error.";
return RET_ERROR;
}
return RET_OK;
}

int ArithmeticFP16CPUKernel::ReSize() { int ArithmeticFP16CPUKernel::ReSize() {
if (CheckDataType() != RET_OK) {
MS_LOG(ERROR) << "ArithmeticFP16CPUKernel resize failed.";
return RET_ERROR;
}

InitParam(); InitParam();


if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) { if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
@@ -131,6 +148,7 @@ int ArithmeticFP16CPUKernel::ReSize() {
MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!"; MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!";
return RET_ERROR; return RET_ERROR;
} }

if (param_->broadcasting_) { if (param_->broadcasting_) {
outside_ = 1; outside_ = 1;
for (int i = param_->ndim_ - 1; i >= 0; --i) { for (int i = param_->ndim_ - 1; i >= 0; --i) {


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h View File

@@ -46,6 +46,7 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
int Init() override; int Init() override;
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
int CheckDataType();
int DoArithmetic(int task_id); int DoArithmetic(int task_id);
int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count, int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count,
int out_thread_stride); int out_thread_stride);


+ 15
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc View File

@@ -46,7 +46,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
* and all need-broadcast-node are const * and all need-broadcast-node are const
* broadcast in resize */ * broadcast in resize */


if (arithmeticParameter_->broadcasting_ == false) {
if (!arithmeticParameter_->broadcasting_) {
return RET_OK; return RET_OK;
} }


@@ -183,7 +183,21 @@ void ArithmeticCPUKernel::InitParam() {
return; return;
} }


int ArithmeticCPUKernel::CheckDataType() {
auto in0_dataType = in_tensors_.at(0)->data_type();
auto in1_dataType = in_tensors_.at(1)->data_type();
if (in0_dataType != in1_dataType) {
MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same.";
return RET_ERROR;
}
return RET_OK;
}

int ArithmeticCPUKernel::ReSize() { int ArithmeticCPUKernel::ReSize() {
if (CheckDataType() != RET_OK) {
MS_LOG(ERROR) << "ArithmeticCPUKernel resize failed.";
return RET_ERROR;
}
InitParam(); InitParam();
return InitBroadCastCase(); return InitBroadCastCase();
} }


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h View File

@@ -80,9 +80,9 @@ class ArithmeticCPUKernel : public LiteKernel {


private: private:
void InitRunFunction(); void InitRunFunction();
void InitOptRunFunction();
void InitParam(); void InitParam();
void FreeTmpPtr(); void FreeTmpPtr();
int CheckDataType();
int InitBroadCastCase(); int InitBroadCastCase();
void InitParamInRunTime(); void InitParamInRunTime();
bool CanBatchScalar(); bool CanBatchScalar();


Loading…
Cancel
Save