| @@ -73,6 +73,11 @@ build_lite_x86_64_jni_and_jar() { | |||
| cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/ | |||
| cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/ | |||
| cp ./libmindspore-lite-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/ | |||
| if [[ "X$is_train" = "Xon" ]]; then | |||
| cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/linux_x86/libs/ | |||
| cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/linux_x86/ | |||
| cp ./libmindspore-lite-train-jni.so ${BASEPATH}/output/tmp/${pkg_name}/runtime/lib/ | |||
| fi | |||
| cd ${LITE_JAVA_PATH}/java | |||
| rm -rf gradle .gradle gradlew gradlew.bat | |||
| @@ -256,6 +261,10 @@ build_lite_arm64_and_jni() { | |||
| fi | |||
| cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ | |||
| cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ | |||
| if [[ "X$is_train" = "Xon" ]]; then | |||
| cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ | |||
| cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ | |||
| fi | |||
| } | |||
| build_lite_arm32_and_jni() { | |||
| @@ -296,6 +305,10 @@ build_lite_arm32_and_jni() { | |||
| fi | |||
| cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||
| cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ | |||
| if [[ "X$is_train" = "Xon" ]]; then | |||
| cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||
| cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ | |||
| fi | |||
| } | |||
| build_aar() { | |||
| @@ -18,6 +18,7 @@ package com.mindspore.lite.train_lenet; | |||
| import com.mindspore.lite.MSTensor; | |||
| import com.mindspore.lite.LiteSession; | |||
| import com.mindspore.lite.TrainSession; | |||
| import com.mindspore.lite.config.MSConfig; | |||
| import java.nio.ByteBuffer; | |||
| @@ -48,7 +49,7 @@ public class NetRunner { | |||
| msConfig.init(0, 2, 0, false); | |||
| session = new LiteSession(); | |||
| System.out.println("Model path is " + modelPath); | |||
| session = session.createTrainSession(modelPath, msConfig, false); | |||
| session = TrainSession.createTrainSession(modelPath, msConfig, false); | |||
| session.setupVirtualBatch(virtualBatch, 0.01f, 1.00f); | |||
| List<MSTensor> inputs = session.getInputs(); | |||
| @@ -63,20 +63,14 @@ public class LiteSession { | |||
| } | |||
| } | |||
| public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { | |||
| LiteSession liteSession = new LiteSession(); | |||
| liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); | |||
| if (liteSession.sessionPtr == 0) { | |||
| return null; | |||
| } else { | |||
| return liteSession; | |||
| } | |||
| } | |||
| public long getSessionPtr() { | |||
| return sessionPtr; | |||
| } | |||
| public void setSessionPtr(long sessionPtr) { | |||
| this.sessionPtr = sessionPtr; | |||
| } | |||
| public void bindThread(boolean ifBind) { | |||
| this.bindThread(this.sessionPtr, ifBind); | |||
| } | |||
| @@ -204,8 +198,6 @@ public class LiteSession { | |||
| private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr); | |||
| private native long createTrainSession(String filename, long msContextPtr, boolean trainMode, long msTrainCfgPtr); | |||
| private native boolean compileGraph(long sessionPtr, long modelPtr); | |||
| private native void bindThread(long sessionPtr, boolean ifBind); | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * <p> | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * <p> | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * <p> | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.lite; | |||
| import com.mindspore.lite.LiteSession; | |||
| import com.mindspore.lite.config.MSConfig; | |||
| public class TrainSession { | |||
| static { | |||
| System.loadLibrary("mindspore-lite-train-jni"); | |||
| } | |||
| public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { | |||
| LiteSession liteSession = new LiteSession(); | |||
| long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); | |||
| if (sessionPtr == 0) { | |||
| return null; | |||
| } else { | |||
| liteSession.setSessionPtr(sessionPtr); | |||
| return liteSession; | |||
| } | |||
| } | |||
| private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, | |||
| long msTrainCfgPtr); | |||
| } | |||
| @@ -18,6 +18,7 @@ package com.mindspore.flclient.model; | |||
| import com.mindspore.flclient.Common; | |||
| import com.mindspore.lite.LiteSession; | |||
| import com.mindspore.lite.TrainSession; | |||
| import com.mindspore.lite.MSTensor; | |||
| import com.mindspore.lite.config.MSConfig; | |||
| import mindspore.schema.FeatureMap; | |||
| @@ -84,7 +85,7 @@ public class SessionUtil { | |||
| // arg 2: cpuBindMode:NO_BIND -> 0 | |||
| // arg 3: enable_fp16 -> false | |||
| msConfig.init(0, 1, 0, false); | |||
| LiteSession trainSession = LiteSession.createTrainSession(modelPath, msConfig,false); | |||
| LiteSession trainSession = TrainSession.createTrainSession(modelPath, msConfig,false); | |||
| if (trainSession == null) { | |||
| logger.severe(Common.addTag("init session failed,please check model path:" + modelPath)); | |||
| return null; | |||
| @@ -63,20 +63,14 @@ public class LiteSession { | |||
| } | |||
| } | |||
| public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { | |||
| LiteSession liteSession = new LiteSession(); | |||
| liteSession.sessionPtr = liteSession.createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); | |||
| if (liteSession.sessionPtr == 0) { | |||
| return null; | |||
| } else { | |||
| return liteSession; | |||
| } | |||
| } | |||
| public long getSessionPtr() { | |||
| return sessionPtr; | |||
| } | |||
| public void setSessionPtr(long sessionPtr) { | |||
| this.sessionPtr = sessionPtr; | |||
| } | |||
| public void bindThread(boolean ifBind) { | |||
| this.bindThread(this.sessionPtr, ifBind); | |||
| } | |||
| @@ -204,8 +198,6 @@ public class LiteSession { | |||
| private native long createSessionWithModel(MappedByteBuffer buffer, long msConfigPtr); | |||
| private native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, long msTrainCfgPtr); | |||
| private native boolean compileGraph(long sessionPtr, long modelPtr); | |||
| private native void bindThread(long sessionPtr, boolean ifBind); | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * <p> | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * <p> | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * <p> | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.lite; | |||
| import com.mindspore.lite.LiteSession; | |||
| import com.mindspore.lite.config.MSConfig; | |||
| public class TrainSession { | |||
| static { | |||
| System.loadLibrary("mindspore-lite-train-jni"); | |||
| } | |||
| public static LiteSession createTrainSession(String modelName, final MSConfig config, boolean trainMode) { | |||
| LiteSession liteSession = new LiteSession(); | |||
| long sessionPtr = createTrainSession(modelName, config.getMSConfigPtr(), trainMode, 0); | |||
| if (sessionPtr == 0) { | |||
| return null; | |||
| } else { | |||
| liteSession.setSessionPtr(sessionPtr); | |||
| return liteSession; | |||
| } | |||
| } | |||
| private static native long createTrainSession(String fileName, long msContextPtr, boolean trainMode, | |||
| long msTrainCfgPtr); | |||
| } | |||
| @@ -92,12 +92,6 @@ set(JNI_SRC | |||
| set(LITE_SO_NAME mindspore-lite) | |||
| if(SUPPORT_TRAIN) | |||
| set(JNI_SRC | |||
| ${JNI_SRC} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp | |||
| ) | |||
| endif() | |||
| add_library(mindspore-lite-jni SHARED ${JNI_SRC}) | |||
| if(PLATFORM_ARM64 OR PLATFORM_ARM32) | |||
| @@ -108,13 +102,15 @@ else() | |||
| endif() | |||
| if(SUPPORT_TRAIN) | |||
| set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite) | |||
| if(PLATFORM_ARM64 OR PLATFORM_ARM32) | |||
| find_library(log-lib log) | |||
| target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME} ${log-lib}) | |||
| else() | |||
| target_link_libraries(mindspore-lite-jni ${LITE_TRAIN_SO_NAME}) | |||
| endif() | |||
| set(LITE_TRAIN_SO_NAME mindspore-lite-train minddata-lite) | |||
| set(JNI_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp) | |||
| add_library(mindspore-lite-train-jni SHARED ${JNI_TRAIN_SRC}) | |||
| if(PLATFORM_ARM64 OR PLATFORM_ARM32) | |||
| find_library(log-lib log) | |||
| target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME} ${log-lib}) | |||
| else() | |||
| target_link_libraries(mindspore-lite-train-jni ${LITE_TRAIN_SO_NAME}) | |||
| endif() | |||
| endif() | |||
| set(NDK_STRIP | |||
| @@ -20,7 +20,7 @@ | |||
| #include "include/train/train_cfg.h" | |||
| #include "include/errorcode.h" | |||
| extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createTrainSession(JNIEnv *env, jobject thiz, | |||
| extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createTrainSession(JNIEnv *env, jobject thiz, | |||
| jstring file_name, | |||
| jlong ms_context_ptr, | |||
| jboolean train_mode, | |||