| @@ -35,6 +35,29 @@ int MergeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t o | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| void MergeDataTypeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size) { | |||||
| for (size_t i = 0; i < outputs_size; i++) { | |||||
| if (inputs[i]->data_type_ == kObjectTypeTensorType) { | |||||
| TensorListC *input_tensor_list = (TensorListC *)inputs[i]; | |||||
| if (input_tensor_list->tensors_data_type_ != kTypeUnknown) { | |||||
| outputs[i] = inputs[i]; | |||||
| inputs[i] = NULL; | |||||
| } else { | |||||
| outputs[i] = inputs[i + outputs_size]; | |||||
| inputs[i + outputs_size] = NULL; | |||||
| } | |||||
| } else { | |||||
| if (inputs[i]->data_type_ != kTypeUnknown) { | |||||
| outputs[i] = inputs[i]; | |||||
| inputs[i] = NULL; | |||||
| } else { | |||||
| outputs[i] = inputs[i + outputs_size]; | |||||
| inputs[i + outputs_size] = NULL; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | ||||
| OpParameter *parameter) { | OpParameter *parameter) { | ||||
| #ifdef Debug | #ifdef Debug | ||||
| @@ -49,6 +72,7 @@ int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC ** | |||||
| #endif | #endif | ||||
| if (!parameter->infer_flag_) { | if (!parameter->infer_flag_) { | ||||
| MergeDataTypeInfer((struct TensorC **)inputs, inputs_size, outputs, outputs_size); | |||||
| return NNACL_INFER_INVALID; | return NNACL_INFER_INVALID; | ||||
| } | } | ||||
| @@ -31,10 +31,6 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * | |||||
| } | } | ||||
| #endif | #endif | ||||
| if (!parameter->infer_flag_) { | |||||
| return NNACL_INFER_INVALID; | |||||
| } | |||||
| for (size_t i = 0; i < outputs_size / 2; i++) { | for (size_t i = 0; i < outputs_size / 2; i++) { | ||||
| outputs[i] = (TensorC *)inputs[i + 1]; | outputs[i] = (TensorC *)inputs[i + 1]; | ||||
| if (inputs[i + 1]->data_type_ == kObjectTypeTensorType) { | if (inputs[i + 1]->data_type_ == kObjectTypeTensorType) { | ||||
| @@ -63,7 +59,9 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * | |||||
| } | } | ||||
| *((const TensorC **)inputs + i + 1) = NULL; | *((const TensorC **)inputs + i + 1) = NULL; | ||||
| } | } | ||||
| if (!parameter->infer_flag_) { | |||||
| return NNACL_INFER_INVALID; | |||||
| } | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -27,13 +27,14 @@ int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_s | |||||
| #endif | #endif | ||||
| TensorListC *output = (TensorListC *)(outputs[0]); | TensorListC *output = (TensorListC *)(outputs[0]); | ||||
| const TensorC *input0 = inputs[0]; | |||||
| output->data_type_ = kObjectTypeTensorType; | output->data_type_ = kObjectTypeTensorType; | ||||
| output->format_ = Format_NHWC; | output->format_ = Format_NHWC; | ||||
| output->tensors_data_type_ = input0->data_type_; | |||||
| if (!parameter->infer_flag_) { | if (!parameter->infer_flag_) { | ||||
| return NNACL_INFER_INVALID; | return NNACL_INFER_INVALID; | ||||
| } | } | ||||
| const TensorC *input0 = inputs[0]; | |||||
| if (input0->shape_size_ < 1) { | if (input0->shape_size_ < 1) { | ||||
| return NNACL_ERR; | return NNACL_ERR; | ||||
| @@ -49,6 +49,7 @@ int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size | |||||
| TensorListC *output0 = (TensorListC *)(outputs[0]); | TensorListC *output0 = (TensorListC *)(outputs[0]); | ||||
| output0->data_type_ = input0->data_type_; | output0->data_type_ = input0->data_type_; | ||||
| output0->format_ = input0->format_; | output0->format_ = input0->format_; | ||||
| output0->tensors_data_type_ = value_tensor->data_type_; | |||||
| if (!parameter->infer_flag_) { | if (!parameter->infer_flag_) { | ||||
| return NNACL_INFER_INVALID; | return NNACL_INFER_INVALID; | ||||
| @@ -25,11 +25,13 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, | |||||
| return check_ret; | return check_ret; | ||||
| } | } | ||||
| #endif | #endif | ||||
| TensorC *output = outputs[0]; | |||||
| TensorListC *input0 = (TensorListC *)(inputs[0]); | |||||
| output->data_type_ = input0->tensors_data_type_; | |||||
| output->format_ = input0->format_; | |||||
| if (!parameter->infer_flag_) { | if (!parameter->infer_flag_) { | ||||
| return NNACL_INFER_INVALID; | return NNACL_INFER_INVALID; | ||||
| } | } | ||||
| TensorListC *input0 = (TensorListC *)(inputs[0]); | |||||
| if (input0->element_num_ == 0) { | if (input0->element_num_ == 0) { | ||||
| return NNACL_ERR; | return NNACL_ERR; | ||||
| } | } | ||||
| @@ -63,9 +65,6 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| TensorC *output = outputs[0]; | |||||
| output->data_type_ = input0->tensors_data_type_; | |||||
| output->format_ = input0->format_; | |||||
| ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_); | ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_); | ||||
| SetShapeArray(output, output_shape, output_shape_size); | SetShapeArray(output, output_shape, output_shape_size); | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| @@ -67,6 +67,11 @@ void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<l | |||||
| tensors_out->at(i)->set_format(static_cast<schema::Format>(tensors_in[i]->format_)); | tensors_out->at(i)->set_format(static_cast<schema::Format>(tensors_in[i]->format_)); | ||||
| tensors_out->at(i)->set_data_type(static_cast<TypeId>(tensors_in[i]->data_type_)); | tensors_out->at(i)->set_data_type(static_cast<TypeId>(tensors_in[i]->data_type_)); | ||||
| tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_}); | tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_}); | ||||
| if (tensors_in.at(i)->data_type_ == TypeIdC::kObjectTypeTensorType) { | |||||
| auto tensor_list_in = reinterpret_cast<TensorListC *>(tensors_in.at(i)); | |||||
| auto tensor_list_out = reinterpret_cast<TensorList *>(tensors_out->at(i)); | |||||
| tensor_list_out->set_tensors_data_type(TypeId(tensor_list_in->tensors_data_type_)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||