|
|
|
@@ -17,6 +17,40 @@ |
|
|
|
#include "nnacl/infer/arithmetic_infer.h" |
|
|
|
#include "nnacl/infer/infer_register.h" |
|
|
|
|
|
|
|
void UpdateInputShape(const int input_shape0_size, const int input_shape1_size, int *ndim, const int *input_shape0, |
|
|
|
const int *input_shape1, int *in_shape0, int *in_shape1) { |
|
|
|
if (input_shape0_size < input_shape1_size) { |
|
|
|
*ndim = input_shape1_size; |
|
|
|
int fill_dim_num = input_shape1_size - input_shape0_size; |
|
|
|
int j = 0; |
|
|
|
for (size_t i = 0; i < input_shape1_size; i++) { |
|
|
|
if (i < fill_dim_num) { |
|
|
|
in_shape0[i] = 1; |
|
|
|
} else { |
|
|
|
in_shape0[i] = input_shape0[j++]; |
|
|
|
} |
|
|
|
in_shape1[i] = input_shape1[i]; |
|
|
|
} |
|
|
|
} else if (input_shape0_size > input_shape1_size) { |
|
|
|
*ndim = input_shape0_size; |
|
|
|
int fill_dim_num = input_shape0_size - input_shape1_size; |
|
|
|
int j = 0; |
|
|
|
for (size_t i = 0; i < input_shape0_size; i++) { |
|
|
|
if (i < fill_dim_num) { |
|
|
|
in_shape1[i] = 1; |
|
|
|
} else { |
|
|
|
in_shape1[i] = input_shape1[j++]; |
|
|
|
} |
|
|
|
in_shape0[i] = input_shape0[i]; |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (size_t i = 0; i < input_shape0_size; i++) { |
|
|
|
in_shape1[i] = input_shape1[i]; |
|
|
|
in_shape0[i] = input_shape0[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, |
|
|
|
OpParameter *parameter) { |
|
|
|
#ifdef Debug |
|
|
|
@@ -46,75 +80,48 @@ int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso |
|
|
|
if (!parameter->infer_flag_) { |
|
|
|
return NNACL_INFER_INVALID; |
|
|
|
} |
|
|
|
if (input_shape0_size > 10 || input_shape1_size > 10) { |
|
|
|
if (input_shape0_size >= MAX_SHAPE_SIZE || input_shape1_size >= MAX_SHAPE_SIZE) { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
int in_shape0_[10]; |
|
|
|
int in_shape1_[10]; |
|
|
|
int out_shape_[10]; |
|
|
|
int in_shape0[10]; |
|
|
|
int in_shape1[10]; |
|
|
|
int out_shape[10]; |
|
|
|
int ndim = input_shape0_size; |
|
|
|
UpdateInputShape(input_shape0_size, input_shape1_size, &ndim, input_shape0, input_shape1, in_shape0, in_shape1); |
|
|
|
|
|
|
|
int ndim_ = input_shape0_size; |
|
|
|
if (input_shape0_size < input_shape1_size) { |
|
|
|
ndim_ = input_shape1_size; |
|
|
|
int fill_dim_num = input_shape1_size - input_shape0_size; |
|
|
|
int j = 0; |
|
|
|
for (size_t i = 0; i < input_shape1_size; i++) { |
|
|
|
if (i < fill_dim_num) { |
|
|
|
in_shape0_[i] = 1; |
|
|
|
} else { |
|
|
|
in_shape0_[i] = input_shape0[j++]; |
|
|
|
} |
|
|
|
in_shape1_[i] = input_shape1[i]; |
|
|
|
} |
|
|
|
} else if (input_shape0_size > input_shape1_size) { |
|
|
|
ndim_ = input_shape0_size; |
|
|
|
int fill_dim_num = input_shape0_size - input_shape1_size; |
|
|
|
int j = 0; |
|
|
|
for (size_t i = 0; i < input_shape0_size; i++) { |
|
|
|
if (i < fill_dim_num) { |
|
|
|
in_shape1_[i] = 1; |
|
|
|
} else { |
|
|
|
in_shape1_[i] = input_shape1[j++]; |
|
|
|
} |
|
|
|
in_shape0_[i] = input_shape0[i]; |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (size_t i = 0; i < input_shape0_size; i++) { |
|
|
|
in_shape1_[i] = input_shape1[i]; |
|
|
|
in_shape0_[i] = input_shape0[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int output_shape[MAX_SHAPE_SIZE]; |
|
|
|
int output_shape[MAX_SHAPE_SIZE] = {0}; |
|
|
|
size_t output_shape_size = 0; |
|
|
|
for (int i = 0; i < ndim_; i++) { |
|
|
|
if (in_shape0_[i] != in_shape1_[i]) { |
|
|
|
if (in_shape0_[i] == 1) { |
|
|
|
out_shape_[i] = in_shape1_[i]; |
|
|
|
} else if (in_shape1_[i] == 1) { |
|
|
|
out_shape_[i] = in_shape0_[i]; |
|
|
|
if (ndim >= MAX_SHAPE_SIZE) { |
|
|
|
return NNACL_INFER_INVALID; |
|
|
|
} |
|
|
|
for (int i = 0; i < ndim; i++) { |
|
|
|
if (in_shape0[i] != in_shape1[i]) { |
|
|
|
if (in_shape0[i] == 1) { |
|
|
|
out_shape[i] = in_shape1[i]; |
|
|
|
} else if (in_shape1[i] == 1) { |
|
|
|
out_shape[i] = in_shape0[i]; |
|
|
|
} else { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
param->broadcasting_ = true; |
|
|
|
} else { |
|
|
|
out_shape_[i] = in_shape0_[i]; |
|
|
|
out_shape[i] = in_shape0[i]; |
|
|
|
} |
|
|
|
output_shape[output_shape_size] = out_shape_[i]; |
|
|
|
output_shape[output_shape_size] = out_shape[i]; |
|
|
|
output_shape_size++; |
|
|
|
} |
|
|
|
|
|
|
|
SetShapeArray(output, output_shape, output_shape_size); |
|
|
|
|
|
|
|
param->ndim_ = ndim_; |
|
|
|
memcpy(param->in_shape0_, in_shape0_, ndim_ * sizeof(int)); |
|
|
|
memcpy(param->in_shape1_, in_shape1_, ndim_ * sizeof(int)); |
|
|
|
memcpy(param->out_shape_, out_shape_, ndim_ * sizeof(int)); |
|
|
|
param->ndim_ = ndim; |
|
|
|
memcpy(param->in_shape0_, in_shape0, ndim * sizeof(int)); |
|
|
|
memcpy(param->in_shape1_, in_shape1, ndim * sizeof(int)); |
|
|
|
memcpy(param->out_shape_, out_shape, ndim * sizeof(int)); |
|
|
|
|
|
|
|
param->in_elements_num0_ = 1; |
|
|
|
param->in_elements_num1_ = 1; |
|
|
|
param->out_elements_num_ = 1; |
|
|
|
for (int i = 0; i < ndim_; i++) { |
|
|
|
for (int i = 0; i < ndim; i++) { |
|
|
|
param->in_elements_num0_ *= param->in_shape0_[i]; |
|
|
|
param->in_elements_num1_ *= param->in_shape1_[i]; |
|
|
|
param->out_elements_num_ *= param->out_shape_[i]; |
|
|
|
|