From: @yeyunpeng2020 Reviewed-by: Signed-off-by:pull/13927/MERGE
| @@ -107,7 +107,7 @@ public class Main { | |||||
| return false; | return false; | ||||
| } | } | ||||
| msgSb.append(" and out data:"); | msgSb.append(" and out data:"); | ||||
| for (int i = 0; i < 10 && i < outTensor.elementsNum(); i++) { | |||||
| for (int i = 0; i < 50 && i < outTensor.elementsNum(); i++) { | |||||
| msgSb.append(" ").append(result[i]); | msgSb.append(" ").append(result[i]); | ||||
| } | } | ||||
| System.out.println(msgSb.toString()); | System.out.println(msgSb.toString()); | ||||
| @@ -16,7 +16,6 @@ | |||||
| #include <jni.h> | #include <jni.h> | ||||
| #include "common/ms_log.h" | #include "common/ms_log.h" | ||||
| #include "common/jni_utils.h" | |||||
| #include "include/lite_session.h" | #include "include/lite_session.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -111,7 +110,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getInputs | |||||
| return jlong(nullptr); | return jlong(nullptr); | ||||
| } | } | ||||
| auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | ||||
| auto input = lite_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name)); | |||||
| auto input = lite_session_ptr->GetInputsByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE)); | |||||
| return jlong(input); | return jlong(input); | ||||
| } | } | ||||
| @@ -131,10 +130,11 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | ||||
| auto inputs = lite_session_ptr->GetOutputsByNodeName(JstringToChar(env, node_name)); | |||||
| auto inputs = lite_session_ptr->GetOutputsByNodeName(env->GetStringUTFChars(node_name, JNI_FALSE)); | |||||
| for (auto input : inputs) { | for (auto input : inputs) { | ||||
| jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input)); | jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input)); | ||||
| env->CallBooleanMethod(ret, array_list_add, tensor_addr); | env->CallBooleanMethod(ret, array_list_add, tensor_addr); | ||||
| env->DeleteLocalRef(tensor_addr); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -155,11 +155,12 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp | |||||
| auto outputs = lite_session_ptr->GetOutputs(); | auto outputs = lite_session_ptr->GetOutputs(); | ||||
| jclass long_object = env->FindClass("java/lang/Long"); | jclass long_object = env->FindClass("java/lang/Long"); | ||||
| jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V"); | jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V"); | ||||
| for (auto output_iter : outputs) { | |||||
| for (const auto &output_iter : outputs) { | |||||
| auto node_name = output_iter.first; | auto node_name = output_iter.first; | ||||
| auto ms_tensor = output_iter.second; | auto ms_tensor = output_iter.second; | ||||
| jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor)); | jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor)); | ||||
| env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr); | env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr); | ||||
| env->DeleteLocalRef(tensor_addr); | |||||
| } | } | ||||
| return hash_map; | return hash_map; | ||||
| } | } | ||||
| @@ -178,7 +179,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp | |||||
| } | } | ||||
| auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | ||||
| auto output_names = lite_session_ptr->GetOutputTensorNames(); | auto output_names = lite_session_ptr->GetOutputTensorNames(); | ||||
| for (auto output_name : output_names) { | |||||
| for (const auto &output_name : output_names) { | |||||
| env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str())); | env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str())); | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| @@ -193,7 +194,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getOutput | |||||
| return jlong(nullptr); | return jlong(nullptr); | ||||
| } | } | ||||
| auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | ||||
| auto output = lite_session_ptr->GetOutputByTensorName(JstringToChar(env, tensor_name)); | |||||
| auto output = lite_session_ptr->GetOutputByTensorName(env->GetStringUTFChars(tensor_name, JNI_FALSE)); | |||||
| return jlong(output); | return jlong(output); | ||||
| } | } | ||||
| @@ -219,7 +220,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_resize | |||||
| } | } | ||||
| auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | ||||
| jsize input_size = static_cast<int>(env->GetArrayLength(inputs)); | |||||
| auto input_size = static_cast<int>(env->GetArrayLength(inputs)); | |||||
| jlong *input_data = env->GetLongArrayElements(inputs, nullptr); | jlong *input_data = env->GetLongArrayElements(inputs, nullptr); | ||||
| std::vector<mindspore::tensor::MSTensor *> c_inputs; | std::vector<mindspore::tensor::MSTensor *> c_inputs; | ||||
| for (int i = 0; i < input_size; i++) { | for (int i = 0; i < input_size; i++) { | ||||
| @@ -231,16 +232,18 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_resize | |||||
| auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer); | auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(tensor_pointer); | ||||
| c_inputs.push_back(ms_tensor_ptr); | c_inputs.push_back(ms_tensor_ptr); | ||||
| } | } | ||||
| jsize tensor_size = static_cast<int>(env->GetArrayLength(dims)); | |||||
| auto tensor_size = static_cast<int>(env->GetArrayLength(dims)); | |||||
| for (int i = 0; i < tensor_size; i++) { | for (int i = 0; i < tensor_size; i++) { | ||||
| jintArray array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i)); | |||||
| jsize dim_size = static_cast<int>(env->GetArrayLength(array)); | |||||
| auto array = static_cast<jintArray>(env->GetObjectArrayElement(dims, i)); | |||||
| auto dim_size = static_cast<int>(env->GetArrayLength(array)); | |||||
| jint *dim_data = env->GetIntArrayElements(array, nullptr); | jint *dim_data = env->GetIntArrayElements(array, nullptr); | ||||
| std::vector<int> tensor_dims; | |||||
| std::vector<int> tensor_dims(dim_size); | |||||
| for (int j = 0; j < dim_size; j++) { | for (int j = 0; j < dim_size; j++) { | ||||
| tensor_dims.push_back(dim_data[j]); | |||||
| tensor_dims[j] = dim_data[j]; | |||||
| } | } | ||||
| c_dims.push_back(tensor_dims); | c_dims.push_back(tensor_dims); | ||||
| env->ReleaseIntArrayElements(array, dim_data, JNI_ABORT); | |||||
| env->DeleteLocalRef(array); | |||||
| } | } | ||||
| int ret = lite_session_ptr->Resize(c_inputs, c_dims); | int ret = lite_session_ptr->Resize(c_inputs, c_dims); | ||||
| return (jboolean)(ret == mindspore::lite::RET_OK); | return (jboolean)(ret == mindspore::lite::RET_OK); | ||||
| @@ -17,7 +17,6 @@ | |||||
| #include <jni.h> | #include <jni.h> | ||||
| #include <fstream> | #include <fstream> | ||||
| #include "common/ms_log.h" | #include "common/ms_log.h" | ||||
| #include "common/jni_utils.h" | |||||
| #include "include/model.h" | #include "include/model.h" | ||||
| extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEnv *env, jobject thiz, jobject buffer) { | extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEnv *env, jobject thiz, jobject buffer) { | ||||
| @@ -38,7 +37,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEn | |||||
| extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath(JNIEnv *env, jobject thiz, | extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath(JNIEnv *env, jobject thiz, | ||||
| jstring model_path) { | jstring model_path) { | ||||
| auto model_path_char = JstringToChar(env, model_path); | |||||
| auto model_path_char = env->GetStringUTFChars(model_path, JNI_FALSE); | |||||
| if (nullptr == model_path_char) { | if (nullptr == model_path_char) { | ||||
| MS_LOGE("model_path_char is nullptr"); | MS_LOGE("model_path_char is nullptr"); | ||||
| return reinterpret_cast<jlong>(nullptr); | return reinterpret_cast<jlong>(nullptr); | ||||
| @@ -56,7 +55,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath | |||||
| ifs.seekg(0, std::ios::end); | ifs.seekg(0, std::ios::end); | ||||
| auto size = ifs.tellg(); | auto size = ifs.tellg(); | ||||
| std::unique_ptr<char[]> buf(new (std::nothrow) char[size]); | |||||
| auto buf = new (std::nothrow) char[size]; | |||||
| if (buf == nullptr) { | if (buf == nullptr) { | ||||
| MS_LOGE("malloc buf failed, file: %s", model_path_char); | MS_LOGE("malloc buf failed, file: %s", model_path_char); | ||||
| ifs.close(); | ifs.close(); | ||||
| @@ -64,10 +63,10 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath | |||||
| } | } | ||||
| ifs.seekg(0, std::ios::beg); | ifs.seekg(0, std::ios::beg); | ||||
| ifs.read(buf.get(), size); | |||||
| ifs.read(buf, size); | |||||
| ifs.close(); | ifs.close(); | ||||
| delete[](model_path_char); | |||||
| auto model = mindspore::lite::Model::Import(buf.get(), size); | |||||
| auto model = mindspore::lite::Model::Import(buf, size); | |||||
| delete[] buf; | |||||
| if (model == nullptr) { | if (model == nullptr) { | ||||
| MS_LOGE("Import model failed"); | MS_LOGE("Import model failed"); | ||||
| return reinterpret_cast<jlong>(nullptr); | return reinterpret_cast<jlong>(nullptr); | ||||
| @@ -2021,6 +2021,38 @@ function Run_npu() { | |||||
| done < ${models_npu_config} | done < ${models_npu_config} | ||||
| } | } | ||||
| # Run on x86 java platform: | |||||
| function Run_x86_java() { | |||||
| cd ${x86_java_path} || exit 1 | |||||
| tar -zxf mindspore-lite-${version}-inference-linux-x64-jar.tar.gz || exit 1 | |||||
| export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${x86_java_path}/mindspore-lite-${version}-inference-linux-x64-jar/jar | |||||
| # compile benchmark | |||||
| echo "javac -cp ${x86_java_path}/mindspore-lite-${version}-inference-linux-x64-jar/jar/mindspore-lite-java.jar ${basepath}/st/java/src/main/java/Benchmark.java -d ." | |||||
| javac -cp ${x86_java_path}/mindspore-lite-${version}-inference-linux-x64-jar/jar/mindspore-lite-java.jar ${basepath}/st/java/src/main/java/Benchmark.java -d . | |||||
| count=0 | |||||
| # Run tflite converted models: | |||||
| while read line; do | |||||
| # only run top5. | |||||
| count=`expr ${count}+1` | |||||
| if [[ ${count} -gt 5 ]]; then | |||||
| break | |||||
| fi | |||||
| model_name=${line} | |||||
| if [[ $model_name == \#* ]]; then | |||||
| continue | |||||
| fi | |||||
| echo ${model_name} >> "${run_x86_java_log_file}" | |||||
| echo "java -classpath .:${x86_java_path}/mindspore-lite-${version}-inference-linux-x64-jar/jar/mindspore-lite-java.jar Benchmark ${ms_models_path}/${model_name}.ms /home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin /home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out 1" >> "${run_x86_java_log_file}" | |||||
| java -classpath .:${x86_java_path}/mindspore-lite-${version}-inference-linux-x64-jar/jar/mindspore-lite-java.jar Benchmark ${ms_models_path}/${model_name}.ms /home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin /home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out 1 | |||||
| if [ $? = 0 ]; then | |||||
| run_result='x86_java: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} | |||||
| else | |||||
| run_result='x86_java: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||||
| fi | |||||
| done < ${models_tflite_config} | |||||
| } | |||||
| # Print start msg before run testcase | # Print start msg before run testcase | ||||
| function MS_PRINT_TESTCASE_START_MSG() { | function MS_PRINT_TESTCASE_START_MSG() { | ||||
| echo "" | echo "" | ||||
| @@ -2158,6 +2190,9 @@ echo 'run x86 sse logs: ' > ${run_x86_sse_log_file} | |||||
| run_x86_avx_log_file=${basepath}/run_x86_avx_log.txt | run_x86_avx_log_file=${basepath}/run_x86_avx_log.txt | ||||
| echo 'run x86 avx logs: ' > ${run_x86_avx_log_file} | echo 'run x86 avx logs: ' > ${run_x86_avx_log_file} | ||||
| run_x86_java_log_file=${basepath}/run_x86_java_log.txt | |||||
| echo 'run x86 java logs: ' > ${run_x86_java_log_file} | |||||
| run_arm64_fp32_log_file=${basepath}/run_arm64_fp32_log.txt | run_arm64_fp32_log_file=${basepath}/run_arm64_fp32_log.txt | ||||
| echo 'run arm64_fp32 logs: ' > ${run_arm64_fp32_log_file} | echo 'run arm64_fp32 logs: ' > ${run_arm64_fp32_log_file} | ||||
| @@ -2207,6 +2242,15 @@ if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86-avx" ]]; th | |||||
| sleep 1 | sleep 1 | ||||
| fi | fi | ||||
| if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86-java" ]]; then | |||||
| # Run on x86-java | |||||
| echo "start Run x86 java ..." | |||||
| x86_java_path=${release_path}/aar | |||||
| Run_x86_java`` & | |||||
| Run_x86_java_PID=$! | |||||
| sleep 1 | |||||
| fi | |||||
| if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp32" ]]; then | if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp32" ]]; then | ||||
| # Run on arm64 | # Run on arm64 | ||||
| arm64_path=${release_path}/android_aarch64 | arm64_path=${release_path}/android_aarch64 | ||||
| @@ -2308,6 +2352,16 @@ if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86" ]]; then | |||||
| isFailed=1 | isFailed=1 | ||||
| fi | fi | ||||
| fi | fi | ||||
| if [[ $backend == "all" || $backend == "x86-all" || $backend == "x86-java" ]]; then | |||||
| wait ${Run_x86_java_PID} | |||||
| Run_x86_java_status=$? | |||||
| if [[ ${Run_x86_java_status} != 0 ]];then | |||||
| echo "Run_x86 java failed" | |||||
| cat ${run_x86_java_log_file} | |||||
| isFailed=1 | |||||
| fi | |||||
| fi | |||||
| if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp32" ]]; then | if [[ $backend == "all" || $backend == "arm_cpu" || $backend == "arm64_fp32" ]]; then | ||||
| if [[ ${Run_arm64_fp32_status} != 0 ]];then | if [[ ${Run_arm64_fp32_status} != 0 ]];then | ||||
| @@ -0,0 +1,52 @@ | |||||
| <?xml version="1.0" encoding="UTF-8"?> | |||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | |||||
| xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | |||||
| xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | |||||
| <modelVersion>4.0.0</modelVersion> | |||||
| <groupId>com.mindspore.lite</groupId> | |||||
| <artifactId>mindspore-lite-java-test</artifactId> | |||||
| <version>1.0</version> | |||||
| <dependencies> | |||||
| <dependency> | |||||
| <groupId>com.mindspore.lite</groupId> | |||||
| <artifactId>mindspore-lite-java</artifactId> | |||||
| <version>1.0</version> | |||||
| <scope>system</scope> | |||||
| <systemPath>${project.basedir}/lib/mindspore-lite-java.jar</systemPath> | |||||
| </dependency> | |||||
| </dependencies> | |||||
| <properties> | |||||
| <maven.compiler.source>8</maven.compiler.source> | |||||
| <maven.compiler.target>8</maven.compiler.target> | |||||
| </properties> | |||||
| <build> | |||||
| <finalName>${project.name}</finalName> | |||||
| <plugins> | |||||
| <plugin> | |||||
| <groupId>org.apache.maven.plugins</groupId> | |||||
| <artifactId>maven-assembly-plugin</artifactId> | |||||
| <configuration> | |||||
| <archive> | |||||
| <manifest> | |||||
| <mainClass>Benchmark</mainClass> | |||||
| </manifest> | |||||
| </archive> | |||||
| <descriptorRefs> | |||||
| <descriptorRef>jar-with-dependencies</descriptorRef> | |||||
| </descriptorRefs> | |||||
| </configuration> | |||||
| <executions> | |||||
| <execution> | |||||
| <id>make-assemble</id> | |||||
| <phase>package</phase> | |||||
| <goals> | |||||
| <goal>single</goal> | |||||
| </goals> | |||||
| </execution> | |||||
| </executions> | |||||
| </plugin> | |||||
| </plugins> | |||||
| </build> | |||||
| </project> | |||||
| @@ -0,0 +1,175 @@ | |||||
| import com.mindspore.lite.DataType; | |||||
| import com.mindspore.lite.LiteSession; | |||||
| import com.mindspore.lite.MSTensor; | |||||
| import com.mindspore.lite.Model; | |||||
| import com.mindspore.lite.config.DeviceType; | |||||
| import com.mindspore.lite.config.MSConfig; | |||||
| import java.io.*; | |||||
| public class Benchmark { | |||||
| private static Model model; | |||||
| private static LiteSession session; | |||||
| public static byte[] readBinFile(String fileName, int size) { | |||||
| try { | |||||
| DataInputStream is = new DataInputStream( | |||||
| new BufferedInputStream(new FileInputStream( | |||||
| fileName))); | |||||
| byte[] buf = new byte[size]; | |||||
| is.read(buf); | |||||
| is.close(); | |||||
| return buf; | |||||
| } catch (IOException e) { | |||||
| e.printStackTrace(); | |||||
| } | |||||
| return null; | |||||
| } | |||||
| public static boolean compareData(String filePath, float accuracy) { | |||||
| double meanError = 0; | |||||
| File file = new File(filePath); | |||||
| if (file.exists()) { | |||||
| try { | |||||
| FileReader fileReader = new FileReader(file); | |||||
| BufferedReader br = new BufferedReader(fileReader); | |||||
| String lineContent = null; | |||||
| int line = 0; | |||||
| MSTensor outTensor = null; | |||||
| String name = null; | |||||
| while ((lineContent = br.readLine()) != null) { | |||||
| String[] strings = lineContent.split(" "); | |||||
| if (line++ % 2 == 0) { | |||||
| name = strings[0]; | |||||
| outTensor = session.getOutputByTensorName(name); | |||||
| continue; | |||||
| } | |||||
| float[] benchmarkData = new float[strings.length]; | |||||
| for (int i = 0; i < strings.length; i++) { | |||||
| benchmarkData[i] = Float.parseFloat(strings[i]); | |||||
| } | |||||
| float[] outData = outTensor.getFloatData(); | |||||
| int errorCount = 0; | |||||
| for (int i = 0; i < benchmarkData.length; i++) { | |||||
| double relativeTolerance = 1e-5; | |||||
| double absoluteTolerance = 1e-8; | |||||
| double tolerance = absoluteTolerance + relativeTolerance * Math.abs(benchmarkData[i]); | |||||
| double absoluteError = Math.abs(outData[i] - benchmarkData[i]); | |||||
| if (absoluteError > tolerance) { | |||||
| if (Math.abs(benchmarkData[i] - 0.0f) < Float.MIN_VALUE) | |||||
| if (absoluteError > 1e-5) { | |||||
| meanError += absoluteError; | |||||
| errorCount++; | |||||
| } else { | |||||
| continue; | |||||
| } | |||||
| } else { | |||||
| meanError += absoluteError / (Math.abs(benchmarkData[i]) + Float.MIN_VALUE); | |||||
| errorCount++; | |||||
| } | |||||
| } | |||||
| if (meanError > 0.0f) { | |||||
| meanError /= errorCount; | |||||
| } | |||||
| if (meanError <= 0.0000001) { | |||||
| System.out.println("Mean bias of node/tensor " + name + " : 0%"); | |||||
| } else { | |||||
| System.out.println("Mean bias of node/tensor " + name + " : " + meanError * 100 + "%"); | |||||
| } | |||||
| } | |||||
| br.close(); | |||||
| fileReader.close(); | |||||
| } catch (IOException e) { | |||||
| e.printStackTrace(); | |||||
| } | |||||
| } | |||||
| return meanError < accuracy; | |||||
| } | |||||
| private static boolean compile() { | |||||
| MSConfig msConfig = new MSConfig(); | |||||
| boolean ret = msConfig.init(DeviceType.DT_CPU, 2); | |||||
| if (!ret) { | |||||
| System.err.println("Init context failed"); | |||||
| return false; | |||||
| } | |||||
| // Create the MindSpore lite session. | |||||
| session = new LiteSession(); | |||||
| ret = session.init(msConfig); | |||||
| msConfig.free(); | |||||
| if (!ret) { | |||||
| System.err.println("Create session failed"); | |||||
| model.free(); | |||||
| return false; | |||||
| } | |||||
| // Compile graph. | |||||
| ret = session.compileGraph(model); | |||||
| if (!ret) { | |||||
| System.err.println("Compile graph failed"); | |||||
| model.free(); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| private static void freeBuffer() { | |||||
| session.free(); | |||||
| model.free(); | |||||
| } | |||||
| public static void main(String[] args) { | |||||
| if (args.length < 4) { | |||||
| System.err.println("We must pass parameters such as modelPath, inDataFile, benchmarkDataFile and accuracy."); | |||||
| return; | |||||
| } | |||||
| String modelPath = args[0]; | |||||
| String[] inDataFile = args[1].split(","); | |||||
| String benchmarkDataFile = args[2]; | |||||
| float accuracy = Float.parseFloat(args[3]); | |||||
| model = new Model(); | |||||
| boolean ret = model.loadModel(modelPath); | |||||
| if (!ret) { | |||||
| System.err.println("Load model failed, model path is " + modelPath); | |||||
| return; | |||||
| } | |||||
| ret = compile(); | |||||
| if (!ret) { | |||||
| System.err.println("MindSpore Lite compile failed."); | |||||
| return; | |||||
| } | |||||
| for (int i = 0; i < session.getInputs().size(); i++) { | |||||
| MSTensor inputTensor = session.getInputs().get(i); | |||||
| if (inputTensor.getDataType() != DataType.kNumberTypeFloat32) { | |||||
| System.err.println("Input tensor shape do not float, the data type is " + inputTensor.getDataType()); | |||||
| freeBuffer(); | |||||
| return; | |||||
| } | |||||
| // Set Input Data. | |||||
| byte[] data = readBinFile(inDataFile[i], (int) inputTensor.size()); | |||||
| inputTensor.setData(data); | |||||
| } | |||||
| // Run Inference. | |||||
| if (!session.runGraph()) { | |||||
| System.err.println("MindSpore Lite run failed."); | |||||
| freeBuffer(); | |||||
| return; | |||||
| } | |||||
| boolean benchmarkResult = compareData(benchmarkDataFile, accuracy); | |||||
| freeBuffer(); | |||||
| if (!benchmarkResult) { | |||||
| System.err.println(modelPath + " accuracy error is too large."); | |||||
| System.exit(1); | |||||
| } | |||||
| } | |||||
| } | |||||