|
|
|
@@ -62,6 +62,19 @@ int GetShapeByType(const TensorC *shape_tensor, size_t shape_size, int32_t *dst_ |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int CheckShape(const int *input_shape, const int *dst_shape, const int input_shape_index, const int dst_shape_index) { |
|
|
|
if (dst_shape[dst_shape_index] < 0) { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
if (input_shape_index >= 0) { |
|
|
|
int input_shape_i = input_shape[input_shape_index]; |
|
|
|
if (input_shape_i != dst_shape[dst_shape_index] && input_shape_i != 1 && dst_shape[dst_shape_index] != 1) { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
} |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, |
|
|
|
OpParameter *parameter) { |
|
|
|
int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); |
|
|
|
@@ -112,21 +125,12 @@ int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, Tens |
|
|
|
size_t input_shape_size = input->shape_size_; |
|
|
|
int shape[MAX_SHAPE_SIZE]; |
|
|
|
int input_shape_index = (int)(input_shape_size)-1; |
|
|
|
if (input_shape_size > dst_shape_size) { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
|
|
|
|
for (int i = (int)(dst_shape_size)-1; i >= 0; --i) { |
|
|
|
if (dst_shape[i] < 0) { |
|
|
|
if (CheckShape(input_shape, dst_shape, input_shape_index, i) != NNACL_OK) { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
if (input_shape_index >= 0) { |
|
|
|
int dim = input_shape[input_shape_index]; |
|
|
|
if (dim != dst_shape[i] && dim != 1) { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
} |
|
|
|
shape[i] = dst_shape[i]; |
|
|
|
shape[i] = dst_shape[i] == 1 ? input_shape[input_shape_index] : dst_shape[i]; |
|
|
|
--input_shape_index; |
|
|
|
} |
|
|
|
SetShapeArray(outputs[0], shape, dst_shape_size); |
|
|
|
|