|
|
|
@@ -388,6 +388,20 @@ int CommonInferShapeWithOneInput(const TensorC *const *inputs, size_t inputs_siz |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int CommonInferShapeWithTwoInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, |
|
|
|
size_t outputs_size, OpParameter *parameter) { |
|
|
|
int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); |
|
|
|
if (ret != NNACL_OK) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
SetDataTypeFormat(outputs[0], inputs[0]); |
|
|
|
if (!InferFlag(inputs, inputs_size)) { |
|
|
|
return NNACL_INFER_INVALID; |
|
|
|
} |
|
|
|
SetShapeTensor(outputs[0], inputs[0]); |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, |
|
|
|
OpParameter *parameter) { |
|
|
|
if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { |
|
|
|
|