|
|
|
@@ -28,7 +28,7 @@ void FullConnection::SetHasBias(bool has_bias) { this->primitive->value.AsFullCo |
|
|
|
void FullConnection::SetAxis(int axis) { this->primitive->value.AsFullConnection()->axis = axis; } |
|
|
|
void FullConnection::SetUseAxis(bool use_axis) { this->primitive->value.AsFullConnection()->useAxis = use_axis; } |
|
|
|
void FullConnection::SetActivationType(int activationType) { |
|
|
|
this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType) activationType; |
|
|
|
this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType)activationType; |
|
|
|
} |
|
|
|
#else |
|
|
|
|
|
|
|
@@ -47,43 +47,58 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_, |
|
|
|
MS_ASSERT(this->primitive != nullptr); |
|
|
|
auto input0 = inputs_.front(); |
|
|
|
MS_ASSERT(input0 != nullptr); |
|
|
|
auto input1 = inputs_.at(1); |
|
|
|
auto input1 = inputs_[1]; |
|
|
|
MS_ASSERT(input1 != nullptr); |
|
|
|
auto output = outputs_.front(); |
|
|
|
MS_ASSERT(output != nullptr); |
|
|
|
output->set_data_type(input0->data_type()); |
|
|
|
output->SetFormat(input0->GetFormat()); |
|
|
|
if (!GetInferFlag()) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) { |
|
|
|
MS_LOG(ERROR) << "Input tensors num error"; |
|
|
|
return 1; |
|
|
|
return RET_INPUT_TENSOR_ERROR; |
|
|
|
} |
|
|
|
if (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size())) { |
|
|
|
if (GetUseAxis() && (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size()))) { |
|
|
|
MS_LOG(ERROR) << "FullConnection axis invalid"; |
|
|
|
return 1; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
int new_k = 1; |
|
|
|
for (size_t i = GetAxis(); i < input0->shape().size(); ++i) { |
|
|
|
new_k *= input0->shape().at(i); |
|
|
|
} |
|
|
|
if (new_k != input1->shape().at(1)) { |
|
|
|
MS_LOG(ERROR) << "Input1 size invalid"; |
|
|
|
return 1; |
|
|
|
if (GetUseAxis()) { |
|
|
|
for (int i = GetAxis(); i < input0->shape().size(); ++i) { |
|
|
|
new_k *= input0->shape()[i]; |
|
|
|
} |
|
|
|
if (new_k != input1->shape()[1]) { |
|
|
|
MS_LOG(ERROR) << "Input1 size invalid"; |
|
|
|
return RET_INPUT_TENSOR_ERROR; |
|
|
|
} |
|
|
|
} else { |
|
|
|
new_k = input1->shape()[1]; |
|
|
|
} |
|
|
|
if (GetHasBias()) { |
|
|
|
if (inputs_.at(2)->shape()[0] != input1->shape()[0]) { |
|
|
|
if (inputs_[2]->shape()[0] != input1->shape()[0]) { |
|
|
|
MS_LOG(ERROR) << "bias size invalid"; |
|
|
|
return 1; |
|
|
|
return RET_INPUT_TENSOR_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
std::vector<int> out_shape{inputs_[0]->shape()}; |
|
|
|
out_shape.resize(GetAxis() + 1); |
|
|
|
out_shape[GetAxis()] = input1->shape()[0]; |
|
|
|
if (GetUseAxis()) { |
|
|
|
out_shape.resize(GetAxis() + 1); |
|
|
|
out_shape[GetAxis()] = input1->shape()[0]; |
|
|
|
} else { |
|
|
|
int total = 1; |
|
|
|
for (int i = 0; i < input0->shape().size(); ++i) { |
|
|
|
total *= input0->shape()[i]; |
|
|
|
} |
|
|
|
out_shape.resize(2); |
|
|
|
auto batch_size = total / new_k; |
|
|
|
out_shape[0] = batch_size; |
|
|
|
out_shape[1] = input1->shape()[0]; |
|
|
|
} |
|
|
|
output->set_shape(out_shape); |
|
|
|
output->set_data_type(input0->data_type()); |
|
|
|
output->SetFormat(input0->GetFormat()); |
|
|
|
|
|
|
|
return 0; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
} // namespace lite |
|
|
|
} // namespace mindspore |