| @@ -73,6 +73,37 @@ Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator); | |||||
| int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | ||||
| MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size()); | MS_ASSERT(2 * (inputs_.size() - 1) == outputs_.size()); | ||||
| for (size_t i = 0; i < outputs_.size() / 2; i++) { | |||||
| auto *input = inputs_[i + 1]; | |||||
| auto *output_true = outputs_[i]; | |||||
| auto *output_false = outputs_[i + outputs_.size() / 2]; | |||||
| if (input == nullptr) { | |||||
| MS_LOG(ERROR) << "input tensor is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (output_true == nullptr || output_false == nullptr) { | |||||
| MS_LOG(ERROR) << "output tensor is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| output_true->set_data_type(input->data_type()); | |||||
| output_false->set_data_type(input->data_type()); | |||||
| output_true->set_format(input->format()); | |||||
| output_false->set_format(input->format()); | |||||
| auto data_type = input->data_type(); | |||||
| if (data_type != kObjectTypeTensorType) { | |||||
| continue; | |||||
| } else { | |||||
| auto input_tensorlist = reinterpret_cast<TensorList *>(input); | |||||
| auto output_true_tensorlist = reinterpret_cast<TensorList *>(output_true); | |||||
| auto output_false_tensorlist = reinterpret_cast<TensorList *>(output_false); | |||||
| output_true_tensorlist->set_element_shape(input_tensorlist->element_shape()); | |||||
| output_false_tensorlist->set_element_shape(input_tensorlist->element_shape()); | |||||
| output_true_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); | |||||
| output_false_tensorlist->set_max_elements_num(input_tensorlist->max_elements_num()); | |||||
| output_true_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); | |||||
| output_false_tensorlist->set_tensors_data_type(input_tensorlist->tensors_data_type()); | |||||
| } | |||||
| } | |||||
| if (!infer_flag()) { | if (!infer_flag()) { | ||||
| return RET_INFER_INVALID; | return RET_INFER_INVALID; | ||||
| } | } | ||||
| @@ -88,12 +119,8 @@ int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp | |||||
| MS_LOG(ERROR) << "output tensor is nullptr"; | MS_LOG(ERROR) << "output tensor is nullptr"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| output_true->set_data_type(input->data_type()); | |||||
| output_false->set_data_type(input->data_type()); | |||||
| output_true->set_shape(input->shape()); | output_true->set_shape(input->shape()); | ||||
| output_false->set_shape(input->shape()); | output_false->set_shape(input->shape()); | ||||
| output_true->set_format(input->format()); | |||||
| output_false->set_format(input->format()); | |||||
| auto data_type = input->data_type(); | auto data_type = input->data_type(); | ||||
| if (data_type != kObjectTypeTensorType) { | if (data_type != kObjectTypeTensorType) { | ||||
| continue; | continue; | ||||
| @@ -118,7 +118,24 @@ void ArithmeticFP16CPUKernel::InitParam() { | |||||
| return; | return; | ||||
| } | } | ||||
| int ArithmeticFP16CPUKernel::CheckDataType() { | |||||
| auto in0_dataType = in_tensors_.at(0)->data_type(); | |||||
| auto in1_dataType = in_tensors_.at(1)->data_type(); | |||||
| if ((in0_dataType != kNumberTypeFloat16 && in0_dataType != kNumberTypeFloat32) || | |||||
| (in1_dataType != kNumberTypeFloat16 && in1_dataType != kNumberTypeFloat32)) { | |||||
| MS_LOG(ERROR) | |||||
| << "The dataTypes of input tensor0 and input tensor1 should be any of float16 and float32, otherwise got error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ArithmeticFP16CPUKernel::ReSize() { | int ArithmeticFP16CPUKernel::ReSize() { | ||||
| if (CheckDataType() != RET_OK) { | |||||
| MS_LOG(ERROR) << "ArithmeticFP16CPUKernel resize failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| InitParam(); | InitParam(); | ||||
| if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) { | if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) { | ||||
| @@ -131,6 +148,7 @@ int ArithmeticFP16CPUKernel::ReSize() { | |||||
| MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!"; | MS_LOG(ERROR) << "arithmetic_opt_func_ and arithmetic_func_ function is nullptr!"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (param_->broadcasting_) { | if (param_->broadcasting_) { | ||||
| outside_ = 1; | outside_ = 1; | ||||
| for (int i = param_->ndim_ - 1; i >= 0; --i) { | for (int i = param_->ndim_ - 1; i >= 0; --i) { | ||||
| @@ -46,6 +46,7 @@ class ArithmeticFP16CPUKernel : public LiteKernel { | |||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int CheckDataType(); | |||||
| int DoArithmetic(int task_id); | int DoArithmetic(int task_id); | ||||
| int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count, | int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count, | ||||
| int out_thread_stride); | int out_thread_stride); | ||||
| @@ -46,7 +46,7 @@ int ArithmeticCPUKernel::InitBroadCastCase() { | |||||
| * and all need-broadcast-node are const | * and all need-broadcast-node are const | ||||
| * broadcast in resize */ | * broadcast in resize */ | ||||
| if (arithmeticParameter_->broadcasting_ == false) { | |||||
| if (!arithmeticParameter_->broadcasting_) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -183,7 +183,21 @@ void ArithmeticCPUKernel::InitParam() { | |||||
| return; | return; | ||||
| } | } | ||||
| int ArithmeticCPUKernel::CheckDataType() { | |||||
| auto in0_dataType = in_tensors_.at(0)->data_type(); | |||||
| auto in1_dataType = in_tensors_.at(1)->data_type(); | |||||
| if (in0_dataType != in1_dataType) { | |||||
| MS_LOG(ERROR) << "The dataTypes of input tensor0 and input tensor1 should be the same."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ArithmeticCPUKernel::ReSize() { | int ArithmeticCPUKernel::ReSize() { | ||||
| if (CheckDataType() != RET_OK) { | |||||
| MS_LOG(ERROR) << "ArithmeticCPUKernel resize failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| InitParam(); | InitParam(); | ||||
| return InitBroadCastCase(); | return InitBroadCastCase(); | ||||
| } | } | ||||
| @@ -80,9 +80,9 @@ class ArithmeticCPUKernel : public LiteKernel { | |||||
| private: | private: | ||||
| void InitRunFunction(); | void InitRunFunction(); | ||||
| void InitOptRunFunction(); | |||||
| void InitParam(); | void InitParam(); | ||||
| void FreeTmpPtr(); | void FreeTmpPtr(); | ||||
| int CheckDataType(); | |||||
| int InitBroadCastCase(); | int InitBroadCastCase(); | ||||
| void InitParamInRunTime(); | void InitParamInRunTime(); | ||||
| bool CanBatchScalar(); | bool CanBatchScalar(); | ||||