Browse Source

add libmindspore-lite-train-jni.so fix lite inference bug

tags/v1.4.0
zhengjun10 4 years ago
parent
commit
ea788db1dd
9 changed files with 113 additions and 40 deletions
  1. +13
    -0
      mindspore/lite/build_lite.sh
  2. +2
    -1
      mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java
  3. +4
    -12
      mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java
  4. +39
    -0
      mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java
  5. +2
    -1
      mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java
  6. +4
    -12
      mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java
  7. +39
    -0
      mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/TrainSession.java
  8. +9
    -13
      mindspore/lite/java/native/CMakeLists.txt
  9. +1
    -1
      mindspore/lite/java/native/runtime/train_session.cpp

+ 13
- 0
mindspore/lite/build_lite.sh View File

@@ -73,6 +73,11 @@ build_lite_x86_64_jni_and_jar() {
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/
cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/
fi

cd ${LITE_JAVA_PATH}/java
rm -rf gradle .gradle gradlew gradlew.bat
@@ -256,6 +261,10 @@ build_lite_arm64_and_jni() {
fi
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/
fi
}

build_lite_arm32_and_jni() {
@@ -296,6 +305,10 @@ build_lite_arm32_and_jni() {
fi
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
if [[ "X$is_train" = "Xon" ]]; then
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/
fi
}

build_aar() {


+ 2
- 1
mindspore/lite/examples/train_lenet_java/src/main/java/com/mindspore/lite/train_lenet/NetRunner.java View File

@@ -18,6 +18,7 @@ package com.mindspore.lite.train_lenet;

import com.mindspore.lite.MSTensor;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.TrainSession;
import com.mindspore.lite.config.MSConfig;

import java.nio.ByteBuffer;
@@ -48,7 +49,7 @@ public class NetRunner {
msConfig.init(0, 2, 0, false);
session = new LiteSession();
System.out.println("Model path is " + modelPath);
session = session.createTrainSession(modelPath, msConfig, false);
session = TrainSession.createTrainSession(modelPath, msConfig, false);
session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f);

List<MSTensor> inputs = session.getInputs();


+ 4
- 12
mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java View File

@@ -63,20 +63,14 @@ public class LiteSession {
}
}

public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (liteSession.sessionPtr == 0) {
return null;
} else {
return liteSession;
}
}

public long getSessionPtr() {
return sessionPtr;
}

public void setSessionPtr(long sessionPtr) {
this.sessionPtr = sessionPtr;
}

public void bindThread(boolean ifBind) {
this.bindThread(this.sessionPtr, ifBind);
}
@@ -204,8 +198,6 @@ public class LiteSession {

private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);

private native long createTrainSession(String filename, long msContextPtr, boolean trainMode, long msTrainCfgPtr);

private native boolean compileGraph(long sessionPtr, long modelPtr);

private native void bindThread(long sessionPtr, boolean ifBind);


+ 39
- 0
mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java View File

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mindspore.lite;

import com.mindspore.lite.LiteSession;
import com.mindspore.lite.config.MSConfig;

public class TrainSession {
static {
System.loadLibrary("mindspore-lite-train-jni");
}
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (sessionPtr == 0) {
return null;
} else {
liteSession.setSessionPtr(sessionPtr);
return liteSession;
}
}

private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode,
long msTrainCfgPtr);
}

+ 2
- 1
mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/model/SessionUtil.java View File

@@ -18,6 +18,7 @@ package com.mindspore.flclient.model;

import com.mindspore.flclient.Common;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.TrainSession;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.config.MSConfig;
import mindspore.schema.FeatureMap;
@@ -84,7 +85,7 @@ public class SessionUtil {
// arg 2: cpuBindMode:NO_BIND -> 0
// arg 3: enable_fp16 -> false
msConfig.init(0, 1, 0, false);
LiteSession trainSession = LiteSession.createTrainSession(modelPath, msConfig,false);
LiteSession trainSession = TrainSession.createTrainSession(modelPath, msConfig,false);
if (trainSession == null) {
logger.severe(Common.addTag("init session failed,please check model path:" + modelPath));
return null;


+ 4
- 12
mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/LiteSession.java View File

@@ -63,20 +63,14 @@ public class LiteSession {
}
}

public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (liteSession.sessionPtr == 0) {
return null;
} else {
return liteSession;
}
}

public long getSessionPtr() {
return sessionPtr;
}

public void setSessionPtr(long sessionPtr) {
this.sessionPtr = sessionPtr;
}

public void bindThread(boolean ifBind) {
this.bindThread(this.sessionPtr, ifBind);
}
@@ -204,8 +198,6 @@ public class LiteSession {

private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr);

private native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, long msTrainCfgPtr);

private native boolean compileGraph(long sessionPtr, long modelPtr);

private native void bindThread(long sessionPtr, boolean ifBind);


+ 39
- 0
mindspore/lite/java/java/linux_x86/src/main/java/com.mindspore.lite/TrainSession.java View File

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mindspore.lite;

import com.mindspore.lite.LiteSession;
import com.mindspore.lite.config.MSConfig;

public class TrainSession {
static {
System.loadLibrary("mindspore-lite-train-jni");
}
public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) {
LiteSession liteSession = new LiteSession();
long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0);
if (sessionPtr == 0) {
return null;
} else {
liteSession.setSessionPtr(sessionPtr);
return liteSession;
}
}

private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode,
long msTrainCfgPtr);
}

+ 9
- 13
mindspore/lite/java/native/CMakeLists.txt View File

@@ -92,12 +92,6 @@ set(JNI_SRC

set(LITE_SO_NAME mindspore-lite)

if(SUPPORT_TRAIN)
set(JNI_SRC
${JNI_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
)
endif()
add_library(mindspore-lite-jni SHARED ${JNI_SRC})

if(PLATFORM_ARM64 OR PLATFORM_ARM32)
@@ -108,13 +102,15 @@ else()
endif()

if(SUPPORT_TRAIN)
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
find_library(log-lib log)
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
else()
target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME})
endif()
set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite)
set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp)
add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC})
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
find_library(log-lib log)
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib})
else()
target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME})
endif()
endif()

set(NDK_STRIP


+ 1
- 1
mindspore/lite/java/native/runtime/train_session.cpp View File

@@ -20,7 +20,7 @@
#include "include/train/train_cfg.h"
#include "include/errorcode.h"

extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(JNIEnv *env, jobject thiz,
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createTrainSession(JNIEnv *env, jobject thiz,
jstring file_name,
jlong ms_context_ptr,
jboolean train_mode,


Loading…
Cancel
Save