Browse Source

infer

pull/15249/head
yefeng 4 years ago
parent
commit
b20597459b
6 changed files with 39 additions and 11 deletions
  1. +24
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/merge_infer.c
  2. +3
    -5
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/switch_infer.c
  3. +2
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c
  4. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c
  5. +4
    -5
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c
  6. +5
    -0
      mindspore/lite/src/common/tensor_util.cc

+ 24
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/merge_infer.c View File

@@ -35,6 +35,29 @@ int MergeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t o
return NNACL_OK;
}

// Route each Merge output to whichever of its two candidate inputs already has a
// known data type. For output i the candidates are inputs[i] and
// inputs[i + outputs_size]; the chosen input pointer is moved into outputs[i]
// and its slot in inputs is cleared (ownership transfer).
// NOTE(review): assumes inputs_size >= 2 * outputs_size — confirm at call sites.
void MergeDataTypeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size) {
  for (size_t i = 0; i < outputs_size; i++) {
    // Read the effective data type: tensor lists carry it in tensors_data_type_,
    // plain tensors in data_type_.
    int effective_type;
    if (inputs[i]->data_type_ == kObjectTypeTensorType) {
      effective_type = ((TensorListC *)inputs[i])->tensors_data_type_;
    } else {
      effective_type = inputs[i]->data_type_;
    }
    // Prefer the first candidate when its type is known; otherwise fall back to
    // the second candidate at offset outputs_size.
    size_t chosen = (effective_type != kTypeUnknown) ? i : i + outputs_size;
    outputs[i] = inputs[chosen];
    inputs[chosen] = NULL;
  }
}

int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
#ifdef Debug
@@ -49,6 +72,7 @@ int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
#endif

if (!parameter->infer_flag_) {
MergeDataTypeInfer((struct TensorC **)inputs, inputs_size, outputs, outputs_size);
return NNACL_INFER_INVALID;
}



+ 3
- 5
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/switch_infer.c View File

@@ -31,10 +31,6 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
}
#endif

if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}

for (size_t i = 0; i < outputs_size / 2; i++) {
outputs[i] = (TensorC *)inputs[i + 1];
if (inputs[i + 1]->data_type_ == kObjectTypeTensorType) {
@@ -63,7 +59,9 @@ int SwitchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
}
*((const TensorC **)inputs + i + 1) = NULL;
}

if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
return NNACL_OK;
}



+ 2
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_fromtensor_infer.c View File

@@ -27,13 +27,14 @@ int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_s
#endif

TensorListC *output = (TensorListC *)(outputs[0]);
const TensorC *input0 = inputs[0];
output->data_type_ = kObjectTypeTensorType;
output->format_ = Format_NHWC;
output->tensors_data_type_ = input0->data_type_;

if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
const TensorC *input0 = inputs[0];

if (input0->shape_size_ < 1) {
return NNACL_ERR;


+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_setitem_infer.c View File

@@ -49,6 +49,7 @@ int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size
TensorListC *output0 = (TensorListC *)(outputs[0]);
output0->data_type_ = input0->data_type_;
output0->format_ = input0->format_;
output0->tensors_data_type_ = value_tensor->data_type_;

if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;


+ 4
- 5
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c View File

@@ -25,11 +25,13 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size,
return check_ret;
}
#endif

TensorC *output = outputs[0];
TensorListC *input0 = (TensorListC *)(inputs[0]);
output->data_type_ = input0->tensors_data_type_;
output->format_ = input0->format_;
if (!parameter->infer_flag_) {
return NNACL_INFER_INVALID;
}
TensorListC *input0 = (TensorListC *)(inputs[0]);
if (input0->element_num_ == 0) {
return NNACL_ERR;
}
@@ -63,9 +65,6 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size,
}
}
}
TensorC *output = outputs[0];
output->data_type_ = input0->tensors_data_type_;
output->format_ = input0->format_;
ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_);
SetShapeArray(output, output_shape, output_shape_size);
return NNACL_OK;


+ 5
- 0
mindspore/lite/src/common/tensor_util.cc View File

@@ -67,6 +67,11 @@ void SetOutputTensorAttr(const std::vector<TensorC *> &tensors_in, std::vector<l
tensors_out->at(i)->set_format(static_cast<schema::Format>(tensors_in[i]->format_));
tensors_out->at(i)->set_data_type(static_cast<TypeId>(tensors_in[i]->data_type_));
tensors_out->at(i)->set_shape({tensors_in[i]->shape_, tensors_in[i]->shape_ + tensors_in[i]->shape_size_});
if (tensors_in.at(i)->data_type_ == TypeIdC::kObjectTypeTensorType) {
auto tensor_list_in = reinterpret_cast<TensorListC *>(tensors_in.at(i));
auto tensor_list_out = reinterpret_cast<TensorList *>(tensors_out->at(i));
tensor_list_out->set_tensors_data_type(TypeId(tensor_list_in->tensors_data_type_));
}
}
}
}


Loading…
Cancel
Save