|
|
|
@@ -50,6 +50,17 @@ int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_ |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int CheckMatMulBias(int *shape, size_t dim_size) { |
|
|
|
if (dim_size > 1) { |
|
|
|
for (size_t i = 0; i < dim_size - 1; i++) { |
|
|
|
if (shape[i] != DIMENSION_1D) { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, |
|
|
|
OpParameter *parameter) { |
|
|
|
TensorC *input0 = (TensorC *)inputs[0]; |
|
|
|
@@ -67,7 +78,7 @@ int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs |
|
|
|
if (inputs_size == kInputSize2) { |
|
|
|
TensorC *bias = (TensorC *)inputs[2]; |
|
|
|
ShapeSet(bias_shape, &bias_shape_size, bias->shape_, bias->shape_size_); |
|
|
|
MS_CHECK_TRUE_RET(bias_shape_size <= DIMENSION_1D, NNACL_ERR); |
|
|
|
MS_CHECK_TRUE_RET(CheckMatMulBias(bias_shape, bias_shape_size) == NNACL_OK, NNACL_ERR); |
|
|
|
} |
|
|
|
|
|
|
|
if (a_shape_size == COMM_SHAPE_SIZE && a_shape[THIRD_INPUT] == 1 && a_shape[FOURTH_INPUT] == 1) { |
|
|
|
|