| @@ -234,17 +234,21 @@ int Benchmark::CompareOutput() { | |||||
| MS_LOG(ERROR) << "Cannot find output node: " << nodeName.c_str() << " , compare output data fail."; | MS_LOG(ERROR) << "Cannot find output node: " << nodeName.c_str() << " , compare output data fail."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| for (auto tensor : tensors) { | |||||
| MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT); | |||||
| MS_ASSERT(tensor->GetData() != nullptr); | |||||
| float bias = CompareData(nodeName, tensor->shape(), static_cast<float *>(tensor->MutableData())); | |||||
| if (bias >= 0) { | |||||
| totalBias += bias; | |||||
| totalSize++; | |||||
| } else { | |||||
| hasError = true; | |||||
| break; | |||||
| } | |||||
| // make sure tensor size is 1 | |||||
| if (tensors.size() != 1) { | |||||
| MS_LOG(ERROR) << "Only support 1 tensor with a name now."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto &tensor = tensors.front(); | |||||
| MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT); | |||||
| MS_ASSERT(tensor->GetData() != nullptr); | |||||
| float bias = CompareData(nodeName, tensor->shape(), static_cast<float *>(tensor->MutableData())); | |||||
| if (bias >= 0) { | |||||
| totalBias += bias; | |||||
| totalSize++; | |||||
| } else { | |||||
| hasError = true; | |||||
| break; | |||||
| } | } | ||||
| } | } | ||||