|
|
|
@@ -35,6 +35,29 @@ int MergeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t o |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void MergeDataTypeInfer(TensorC **inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size) { |
|
|
|
for (size_t i = 0; i < outputs_size; i++) { |
|
|
|
if (inputs[i]->data_type_ == kObjectTypeTensorType) { |
|
|
|
TensorListC *input_tensor_list = (TensorListC *)inputs[i]; |
|
|
|
if (input_tensor_list->tensors_data_type_ != kTypeUnknown) { |
|
|
|
outputs[i] = inputs[i]; |
|
|
|
inputs[i] = NULL; |
|
|
|
} else { |
|
|
|
outputs[i] = inputs[i + outputs_size]; |
|
|
|
inputs[i + outputs_size] = NULL; |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (inputs[i]->data_type_ != kTypeUnknown) { |
|
|
|
outputs[i] = inputs[i]; |
|
|
|
inputs[i] = NULL; |
|
|
|
} else { |
|
|
|
outputs[i] = inputs[i + outputs_size]; |
|
|
|
inputs[i + outputs_size] = NULL; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, |
|
|
|
OpParameter *parameter) { |
|
|
|
#ifdef Debug |
|
|
|
@@ -49,6 +72,7 @@ int MergeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC ** |
|
|
|
#endif |
|
|
|
|
|
|
|
if (!parameter->infer_flag_) { |
|
|
|
MergeDataTypeInfer((struct TensorC **)inputs, inputs_size, outputs, outputs_size); |
|
|
|
return NNACL_INFER_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
|