diff --git a/mindspore/lite/src/ops/switch.cc b/mindspore/lite/src/ops/switch.cc index eacbd2cf7e..0e08a17d08 100644 --- a/mindspore/lite/src/ops/switch.cc +++ b/mindspore/lite/src/ops/switch.cc @@ -73,6 +73,37 @@ Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator); int Switch::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size()); + for (size_t i = 0; i < outputs_.size() / 2; i++) { + auto *input = inputs_[i + 1]; + auto *output_true = outputs_[i]; + auto *output_false = outputs_[i + outputs_.size() / 2]; + if (input == nullptr) { + MS_LOG(ERROR) << "input tensor is nullptr"; + return RET_ERROR; + } + if (output_true == nullptr || output_false == nullptr) { + MS_LOG(ERROR) << "output tensor is nullptr"; + return RET_ERROR; + } + output_true->set_data_type(input->data_type()); + output_false->set_data_type(input->data_type()); + output_true->set_format(input->format()); + output_false->set_format(input->format()); + auto data_type = input->data_type(); + if (data_type != kObjectTypeTensorType) { + continue; + } else { + auto input_tensorlist = reinterpret_cast(input); + auto output_true_tensorlist = reinterpret_cast(output_true); + auto output_false_tensorlist = reinterpret_cast(output_false); + output_true_tensorlist->set_element_shape(input_tensorlist->element_shape()); + output_false_tensorlist->set_element_shape(input_tensorlist->element_shape()); + output_true_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); + output_false_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); + output_true_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); + output_false_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); + } + } if (!infer_flag()) { return RET_INFER_INVALID; } @@ -88,12 +119,8 @@ int Switch::InferShape(std::vector inputs_, std::vector outp MS_LOG(ERROR) << "output tensor is nullptr"; return RET_ERROR; } - output_true->set_data_type(input->data_type()); - output_false->set_data_type(input->data_type()); output_true->set_shape(input->shape()); output_false->set_shape(input->shape()); - output_true->set_format(input->format()); - output_false->set_format(input->format()); auto data_type = input->data_type(); if (data_type != kObjectTypeTensorType) { continue; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index 888bb8b91d..c73ea19b7b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -118,7 +118,24 @@ void ArithmeticFP16CPUKernel::InitParam() { return; } +int ArithmeticFP16CPUKernel::CheckDataType() { + auto in0_dataType = in_tensors_.at(0)->data_type(); + auto in1_dataType = in_tensors_.at(1)->data_type(); + if ((in0_dataType != kNumberTypeFloat16 && in0_dataType != kNumberTypeFloat32) || + (in1_dataType != kNumberTypeFloat16 && in1_dataType != kNumberTypeFloat32)) { + MS_LOG(ERROR) + << "The dataTypes of input tensor0 and input tensor1 should be any of float16 and float32, otherwise got error."; + return RET_ERROR; + } + return RET_OK; +} + int ArithmeticFP16CPUKernel::ReSize() { + if (CheckDataType() != RET_OK) { + MS_LOG(ERROR) << "ArithmeticFP16CPUKernel resize failed."; + return RET_ERROR; + } + InitParam(); if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) { @@ -131,6 +148,7 @@ int ArithmeticFP16CPUKernel::ReSize() { MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!"; return RET_ERROR; } + if (param_->broadcasting_) { outside_ = 1; for (int i = param_->ndim_ - 1; i >= 0; --i) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h index 5e95858747..a36de76573 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h @@ -46,6 +46,7 @@ class ArithmeticFP16CPUKernel : public LiteKernel { int Init() override; int ReSize() override; int Run() override; + int CheckDataType(); int DoArithmetic(int task_id); int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count, int out_thread_stride); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index 51db2c8152..042b2675bc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -46,7 +46,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() { * and all need-broadcast-node are const * broadcast in resize */ - if (arithmeticParameter_->broadcasting_ == false) { + if (!arithmeticParameter_->broadcasting_) { return RET_OK; } @@ -183,7 +183,21 @@ void ArithmeticCPUKernel::InitParam() { return; } +int ArithmeticCPUKernel::CheckDataType() { + auto in0_dataType = in_tensors_.at(0)->data_type(); + auto in1_dataType = in_tensors_.at(1)->data_type(); + if (in0_dataType != in1_dataType) { + MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same."; + return RET_ERROR; + } + return RET_OK; +} + int ArithmeticCPUKernel::ReSize() { + if (CheckDataType() != RET_OK) { + MS_LOG(ERROR) << "ArithmeticCPUKernel resize failed."; + return RET_ERROR; + } InitParam(); return InitBroadCastCase(); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index c5c0842c53..11155bd99c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -80,9 +80,9 @@ class ArithmeticCPUKernel : public LiteKernel { private: void InitRunFunction(); - void InitOptRunFunction(); void InitParam(); void FreeTmpPtr(); + int CheckDataType(); int InitBroadCastCase(); void InitParamInRunTime(); bool CanBatchScalar();