From b20597459b0fc738baf21fbd1879295cb1555733 Mon Sep 17 00:00:00 2001 From: yefeng Date: Thu, 15 Apr 2021 21:57:37 +0800 Subject: [PATCH] infer --- .../cpu/nnacl/infer/merge_infer.c | 24 +++++++++++++++++++ .../cpu/nnacl/infer/switch_infer.c | 8 +++---- .../nnacl/infer/tensorlist_fromtensor_infer.c | 3 ++- .../nnacl/infer/tensorlist_setitem_infer.c | 1 + .../cpu/nnacl/infer/tensorlist_stack_infer.c | 9 ++++--- mindspore/lite/src/common/tensor_util.cc | 5 ++++ 6 files changed, 39 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/merge_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/merge_infer.c index 4fa2b68fab..276c503c7a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/merge_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/merge_infer.c @@ -35,6 +35,29 @@ int MergeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t o return NNACL_OK; } +void MergeDataTypeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size) { + for (size_t i = 0; i < outputs_size; i++) { + if (inputs[i]->data_type_ == kObjectTypeTensorType) { + TensorListC *input_tensor_list = (TensorListC *)inputs[i]; + if (input_tensor_list->tensors_data_type_ != kTypeUnknown) { + outputs[i] = inputs[i]; + inputs[i] = NULL; + } else { + outputs[i] = inputs[i + outputs_size]; + inputs[i + outputs_size] = NULL; + } + } else { + if (inputs[i]->data_type_ != kTypeUnknown) { + outputs[i] = inputs[i]; + inputs[i] = NULL; + } else { + outputs[i] = inputs[i + outputs_size]; + inputs[i + outputs_size] = NULL; + } + } + } +} + int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { #ifdef Debug @@ -49,6 +72,7 @@ int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC ** #endif if (!parameter->infer_flag_) { + MergeDataTypeInfer((struct TensorC **)inputs, inputs_size, outputs, outputs_size); return NNACL_INFER_INVALID; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/switch_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/switch_infer.c index 3636329dde..8672e3f785 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/switch_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/switch_infer.c @@ -31,10 +31,6 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * } #endif - if (!parameter->infer_flag_) { - return NNACL_INFER_INVALID; - } - for (size_t i = 0; i < outputs_size / 2; i++) { outputs[i] = (TensorC *)inputs[i + 1]; if (inputs[i + 1]->data_type_ == kObjectTypeTensorType) { @@ -63,7 +59,9 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * } *((const TensorC **)inputs + i + 1) = NULL; } - + if (!parameter->infer_flag_) { + return NNACL_INFER_INVALID; + } return NNACL_OK; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c index e4b6c1c84a..f85875dce3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c @@ -27,13 +27,14 @@ int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_s #endif TensorListC *output = (TensorListC *)(outputs[0]); + const TensorC *input0 = inputs[0]; output->data_type_ = kObjectTypeTensorType; output->format_ = Format_NHWC; + output->tensors_data_type_ = input0->data_type_; if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } - const TensorC *input0 = inputs[0]; if (input0->shape_size_ < 1) { return NNACL_ERR; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c index c764082b16..0a32bda4c7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c @@ -49,6 +49,7 @@ int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size TensorListC *output0 = (TensorListC *)(outputs[0]); output0->data_type_ = input0->data_type_; output0->format_ = input0->format_; + output0->tensors_data_type_ = value_tensor->data_type_; if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c index 361e64e576..3402dd1f1d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c @@ -25,11 +25,13 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, return check_ret; } #endif - + TensorC *output = outputs[0]; + TensorListC *input0 = (TensorListC *)(inputs[0]); + output->data_type_ = input0->tensors_data_type_; + output->format_ = input0->format_; if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } - TensorListC *input0 = (TensorListC *)(inputs[0]); if (input0->element_num_ == 0) { return NNACL_ERR; } @@ -63,9 +65,6 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, } } } - TensorC *output = outputs[0]; - output->data_type_ = input0->tensors_data_type_; - output->format_ = input0->format_; ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_); SetShapeArray(output, output_shape, output_shape_size); return NNACL_OK; diff --git a/mindspore/lite/src/common/tensor_util.cc b/mindspore/lite/src/common/tensor_util.cc index f654c37c9f..63e808e7c2 100644 --- a/mindspore/lite/src/common/tensor_util.cc +++ b/mindspore/lite/src/common/tensor_util.cc @@ -67,6 +67,11 @@ void SetOutputTensorAttr(const std::vector &tensors_in, std::vectorat(i)->set_format(static_cast(tensors_in[i]->format_)); tensors_out->at(i)->set_data_type(static_cast(tensors_in[i]->data_type_)); tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_}); + if (tensors_in.at(i)->data_type_ == TypeIdC::kObjectTypeTensorType) { + auto tensor_list_in = reinterpret_cast(tensors_in.at(i)); + auto tensor_list_out = reinterpret_cast(tensors_out->at(i)); + tensor_list_out->set_tensors_data_type(TypeId(tensor_list_in->tensors_data_type_)); + } } } }