| @@ -671,6 +671,12 @@ build_lite_java_arm64() { | |||||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/ | cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/ | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/arm64-v8a/ | |||||
| else | else | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/ | cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/ | ||||
| @@ -697,6 +703,12 @@ build_lite_java_arm32() { | |||||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | |||||
| else | else | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | ||||
| @@ -706,10 +718,15 @@ build_lite_java_arm32() { | |||||
| build_lite_java_x86() { | build_lite_java_x86() { | ||||
| # build mindspore-lite x86 | # build mindspore-lite x86 | ||||
| local inference_or_train=inference | |||||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||||
| inference_or_train=train | |||||
| fi | |||||
| if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then | if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then | ||||
| local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD} | |||||
| local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD} | |||||
| else | else | ||||
| local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64 | |||||
| local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64 | |||||
| fi | fi | ||||
| if [[ "X$INC_BUILD" == "Xoff" ]] || [[ ! -f "${BASEPATH}/mindspore/lite/build/java/${JTARBALL}.tar.gz" ]]; then | if [[ "X$INC_BUILD" == "Xoff" ]] || [[ ! -f "${BASEPATH}/mindspore/lite/build/java/${JTARBALL}.tar.gz" ]]; then | ||||
| build_lite "x86_64" "off" "" | build_lite "x86_64" "off" "" | ||||
| @@ -721,8 +738,20 @@ build_lite_java_x86() { | |||||
| [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/linux_x86/libs/ | [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/linux_x86/libs/ | ||||
| mkdir -p ${JAVA_PATH}/java/linux_x86/libs/ | mkdir -p ${JAVA_PATH}/java/linux_x86/libs/ | ||||
| mkdir -p ${JAVA_PATH}/native/libs/linux_x86/ | mkdir -p ${JAVA_PATH}/native/libs/linux_x86/ | ||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ | |||||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/linux_x86/libs/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/linux_x86/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/java/linux_x86/libs/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/native/libs/linux_x86/ | |||||
| else | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ | |||||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ | |||||
| fi | |||||
| [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL} | |||||
| } | } | ||||
| build_jni_arm64() { | build_jni_arm64() { | ||||
| @@ -776,7 +805,7 @@ build_jni_x86_64() { | |||||
| mkdir -pv java/jni | mkdir -pv java/jni | ||||
| cd java/jni | cd java/jni | ||||
| cmake -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ | cmake -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ | ||||
| -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${JAVA_PATH}/native/" | |||||
| -DENABLE_VERBOSE=${ENABLE_VERBOSE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} "${JAVA_PATH}/native/" | |||||
| make -j$THREAD_NUM | make -j$THREAD_NUM | ||||
| if [[ $? -ne 0 ]]; then | if [[ $? -ne 0 ]]; then | ||||
| echo "---------------- mindspore lite: build jni x86_64 failed----------------" | echo "---------------- mindspore lite: build jni x86_64 failed----------------" | ||||
| @@ -825,11 +854,16 @@ build_java() { | |||||
| cd ${JAVA_PATH}/java/app/build | cd ${JAVA_PATH}/java/app/build | ||||
| zip -r mindspore-lite-maven-${VERSION_STR}.zip mindspore | zip -r mindspore-lite-maven-${VERSION_STR}.zip mindspore | ||||
| local inference_or_train=inference | |||||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||||
| inference_or_train=train | |||||
| fi | |||||
| # build linux x86 jar | # build linux x86 jar | ||||
| if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then | if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then | ||||
| local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD}-jar | |||||
| local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}-jar | |||||
| else | else | ||||
| local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-jar | |||||
| local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-jar | |||||
| fi | fi | ||||
| check_java_home | check_java_home | ||||
| build_lite_java_x86 | build_lite_java_x86 | ||||
| @@ -843,15 +877,17 @@ build_java() { | |||||
| gradle releaseJar | gradle releaseJar | ||||
| # install and package | # install and package | ||||
| mkdir -p ${JAVA_PATH}/java/linux_x86/build/lib | mkdir -p ${JAVA_PATH}/java/linux_x86/build/lib | ||||
| cp ${JAVA_PATH}/java/linux_x86/libs/*.so ${JAVA_PATH}/java/linux_x86/build/lib/jar | |||||
| cp ${JAVA_PATH}/java/linux_x86/libs/*.so* ${JAVA_PATH}/java/linux_x86/build/lib/jar | |||||
| cd ${JAVA_PATH}/java/linux_x86/build/ | cd ${JAVA_PATH}/java/linux_x86/build/ | ||||
| cp -r ${JAVA_PATH}/java/linux_x86/build/lib ${JAVA_PATH}/java/linux_x86/build/${LINUX_X86_PACKAGE_NAME} | cp -r ${JAVA_PATH}/java/linux_x86/build/lib ${JAVA_PATH}/java/linux_x86/build/${LINUX_X86_PACKAGE_NAME} | ||||
| tar czvf ${LINUX_X86_PACKAGE_NAME}.tar.gz ${LINUX_X86_PACKAGE_NAME} | tar czvf ${LINUX_X86_PACKAGE_NAME}.tar.gz ${LINUX_X86_PACKAGE_NAME} | ||||
| # copy output | # copy output | ||||
| cp ${JAVA_PATH}/java/app/build/mindspore-lite-maven-${VERSION_STR}.zip ${BASEPATH}/output | cp ${JAVA_PATH}/java/app/build/mindspore-lite-maven-${VERSION_STR}.zip ${BASEPATH}/output | ||||
| cp ${LINUX_X86_PACKAGE_NAME}.tar.gz ${BASEPATH}/output | cp ${LINUX_X86_PACKAGE_NAME}.tar.gz ${BASEPATH}/output | ||||
| cd ${BASEPATH}/output | cd ${BASEPATH}/output | ||||
| [ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-inference-linux-x64 | |||||
| [ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64 | |||||
| exit 0 | exit 0 | ||||
| } | } | ||||
| @@ -0,0 +1,55 @@ | |||||
| <?xml version="1.0" encoding="UTF-8"?> | |||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | |||||
| xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | |||||
| xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | |||||
| <modelVersion>4.0.0</modelVersion> | |||||
| <groupId>com.mindspore.lite.demo</groupId> | |||||
| <artifactId>train_lenet_java</artifactId> | |||||
| <version>1.0</version> | |||||
| <properties> | |||||
| <maven.compiler.source>8</maven.compiler.source> | |||||
| <maven.compiler.target>8</maven.compiler.target> | |||||
| </properties> | |||||
| <dependencies> | |||||
| <dependency> | |||||
| <groupId>com.mindspore.lite</groupId> | |||||
| <artifactId>mindspore-lite-java</artifactId> | |||||
| <version>1.0</version> | |||||
| <scope>system</scope> | |||||
| <systemPath>${project.basedir}/lib/mindspore-lite-java.jar</systemPath> | |||||
| </dependency> | |||||
| </dependencies> | |||||
| <build> | |||||
| <finalName>${project.name}</finalName> | |||||
| <plugins> | |||||
| <plugin> | |||||
| <groupId>org.apache.maven.plugins</groupId> | |||||
| <artifactId>maven-assembly-plugin</artifactId> | |||||
| <configuration> | |||||
| <archive> | |||||
| <manifest> | |||||
| <mainClass>com.mindspore.lite.train_lenet.Main</mainClass> | |||||
| </manifest> | |||||
| </archive> | |||||
| <descriptorRefs> | |||||
| <descriptorRef>jar-with-dependencies</descriptorRef> | |||||
| </descriptorRefs> | |||||
| </configuration> | |||||
| <executions> | |||||
| <execution> | |||||
| <id>make-assemble</id> | |||||
| <phase>package</phase> | |||||
| <goals> | |||||
| <goal>single</goal> | |||||
| </goals> | |||||
| </execution> | |||||
| </executions> | |||||
| </plugin> | |||||
| </plugins> | |||||
| </build> | |||||
| </project> | |||||
| @@ -0,0 +1,131 @@ | |||||
| package com.mindspore.lite.train_lenet; | |||||
| import java.io.BufferedInputStream; | |||||
| import java.io.FileInputStream; | |||||
| import java.io.IOException; | |||||
| import java.util.Vector; | |||||
| public class DataSet { | |||||
| private long numOfClasses = 0; | |||||
| private long expectedDataSize = 0; | |||||
| public class DataLabelTuple { | |||||
| public float[] data; | |||||
| public int label; | |||||
| } | |||||
| Vector<DataLabelTuple> trainData; | |||||
| Vector<DataLabelTuple> testData; | |||||
| public void initializeMNISTDatabase(String dpath) { | |||||
| numOfClasses = 10; | |||||
| trainData = new Vector<DataLabelTuple>(); | |||||
| testData = new Vector<DataLabelTuple>(); | |||||
| readMNISTFile(dpath + "/train/train-images-idx3-ubyte", dpath+"/train/train-labels-idx1-ubyte", trainData); | |||||
| readMNISTFile(dpath + "/test/t10k-images-idx3-ubyte", dpath+"/test/t10k-labels-idx1-ubyte", testData); | |||||
| System.out.println("train data cnt: " + trainData.size()); | |||||
| System.out.println("test data cnt: " + testData.size()); | |||||
| } | |||||
| private String bytesToHex(byte[] bytes) { | |||||
| StringBuffer sb = new StringBuffer(); | |||||
| for (int i = 0; i < bytes.length; i++) { | |||||
| String hex = Integer.toHexString(bytes[i] & 0xFF); | |||||
| if (hex.length() < 2) { | |||||
| sb.append(0); | |||||
| } | |||||
| sb.append(hex); | |||||
| } | |||||
| return sb.toString(); | |||||
| } | |||||
| private void readFile(BufferedInputStream inputStream, byte[] bytes, int len) throws IOException { | |||||
| int result = inputStream.read(bytes, 0, len); | |||||
| if (result != len) { | |||||
| System.err.println("expected read " + len + " bytes, but " + result + " read"); | |||||
| System.exit(1); | |||||
| } | |||||
| } | |||||
| public void readMNISTFile(String inputFileName, String labelFileName, Vector<DataLabelTuple> dataset) { | |||||
| try { | |||||
| BufferedInputStream ibin = new BufferedInputStream(new FileInputStream(inputFileName)); | |||||
| BufferedInputStream lbin = new BufferedInputStream(new FileInputStream(labelFileName)); | |||||
| byte[] bytes = new byte[4]; | |||||
| readFile(ibin, bytes, 4); | |||||
| if (!"00000803".equals(bytesToHex(bytes))) { // 2051 | |||||
| System.err.println("The dataset is not valid: " + bytesToHex(bytes)); | |||||
| return; | |||||
| } | |||||
| readFile(ibin, bytes, 4); | |||||
| int inumber = Integer.parseInt(bytesToHex(bytes), 16); | |||||
| readFile(lbin, bytes, 4); | |||||
| if (!"00000801".equals(bytesToHex(bytes))) { // 2049 | |||||
| System.err.println("The dataset label is not valid: " + bytesToHex(bytes)); | |||||
| return; | |||||
| } | |||||
| readFile(lbin, bytes, 4); | |||||
| int lnumber = Integer.parseInt(bytesToHex(bytes), 16); | |||||
| if (inumber != lnumber) { | |||||
| System.err.println("input data cnt: " + inumber + " not equal label cnt: " + lnumber); | |||||
| return; | |||||
| } | |||||
| // read all labels | |||||
| byte[] labels = new byte[lnumber]; | |||||
| readFile(lbin, labels, lnumber); | |||||
| // row, column | |||||
| readFile(ibin, bytes, 4); | |||||
| int n_rows = Integer.parseInt(bytesToHex(bytes), 16); | |||||
| readFile(ibin, bytes, 4); | |||||
| int n_cols = Integer.parseInt(bytesToHex(bytes), 16); | |||||
| if (n_rows != 28 || n_cols != 28) { | |||||
| System.err.println("invalid n_rows: " + n_rows + " n_cols: " + n_cols); | |||||
| return; | |||||
| } | |||||
| // read images | |||||
| int image_size = n_rows * n_cols; | |||||
| byte[] image_data = new byte[image_size]; | |||||
| for (int i = 0; i < lnumber; i++) { | |||||
| float [] hwc_bin_image = new float[32 * 32]; | |||||
| readFile(ibin, image_data, image_size); | |||||
| for (int r = 0; r < 32; r++) { | |||||
| for (int c = 0; c < 32; c++) { | |||||
| int index = r * 32 + c; | |||||
| if (r < 2 || r > 29 || c < 2 || c > 29) { | |||||
| hwc_bin_image[index] = 0; | |||||
| } else { | |||||
| int data = image_data[(r-2)*28 + (c-2)] & 0xff; | |||||
| hwc_bin_image[index] = (float)data / 255.0f; | |||||
| } | |||||
| } | |||||
| } | |||||
| DataLabelTuple data_label_tupel = new DataLabelTuple(); | |||||
| data_label_tupel.data = hwc_bin_image; | |||||
| data_label_tupel.label = labels[i] & 0xff; | |||||
| dataset.add(data_label_tupel); | |||||
| } | |||||
| } catch (IOException e) { | |||||
| System.err.println("Read Dateset exception"); | |||||
| } | |||||
| } | |||||
| public void setExpectedDataSize(long data_size) { | |||||
| expectedDataSize = data_size; | |||||
| } | |||||
| public long getNumOfClasses() { | |||||
| return numOfClasses; | |||||
| } | |||||
| public Vector<DataLabelTuple> getTestData() { | |||||
| return testData; | |||||
| } | |||||
| public Vector<DataLabelTuple> getTrainData() { | |||||
| return trainData; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,20 @@ | |||||
| package com.mindspore.lite.train_lenet; | |||||
| import com.mindspore.lite.Version; | |||||
| public class Main { | |||||
| public static void main(String[] args) { | |||||
| System.out.println(Version.version()); | |||||
| if (args.length < 2) { | |||||
| System.err.println("model path and dataset path must be provided."); | |||||
| return; | |||||
| } | |||||
| String modelPath = args[0]; | |||||
| String datasetPath = args[1]; | |||||
| NetRunner net_runner = new NetRunner(); | |||||
| net_runner.trainModel(modelPath, datasetPath); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,220 @@ | |||||
| package com.mindspore.lite.train_lenet; | |||||
| import com.mindspore.lite.MSTensor; | |||||
| import com.mindspore.lite.TrainSession; | |||||
| import com.mindspore.lite.config.MSConfig; | |||||
| import java.nio.ByteBuffer; | |||||
| import java.nio.ByteOrder; | |||||
| import java.util.List; | |||||
| import java.util.Map; | |||||
| import java.util.Vector; | |||||
| public class NetRunner { | |||||
| private int dataIndex = 0; | |||||
| private int labelIndex = 1; | |||||
| private TrainSession session; | |||||
| private long batchSize; | |||||
| private long dataSize; // one input data size, in byte | |||||
| private DataSet ds = new DataSet(); | |||||
| private long numOfClasses; | |||||
| private long cycles = 3500; | |||||
| private int idx = 1; | |||||
| private String trainedFilePath = "trained.ms"; | |||||
| public void initAndFigureInputs(String modelPath) { | |||||
| MSConfig msConfig = new MSConfig(); | |||||
| // arg 0: DeviceType:DT_CPU -> 0 | |||||
| // arg 1: ThreadNum -> 2 | |||||
| // arg 2: cpuBindMode:NO_BIND -> 0 | |||||
| // arg 3: enable_fp16 -> false | |||||
| msConfig.init(0, 2, 0, false); | |||||
| session = new TrainSession(); | |||||
| session.init(modelPath, msConfig); | |||||
| session.setLearningRate(0.01f); | |||||
| List<MSTensor> inputs = session.getInputs(); | |||||
| if (inputs.size() <= 1) { | |||||
| System.err.println("model input size: " + inputs.size()); | |||||
| return; | |||||
| } | |||||
| dataIndex = 0; | |||||
| labelIndex = 1; | |||||
| batchSize = inputs.get(dataIndex).getShape()[0]; | |||||
| dataSize = inputs.get(dataIndex).size() / batchSize; | |||||
| System.out.println("batch_size: " + batchSize); | |||||
| int index = modelPath.lastIndexOf(".ms"); | |||||
| if (index == -1) { | |||||
| System.out.println("The model " + modelPath + " should be named *.ms"); | |||||
| return; | |||||
| } | |||||
| trainedFilePath = modelPath.substring(0, index) + "_trained.ms"; | |||||
| } | |||||
| public int initDB(String datasetPath) { | |||||
| if (dataSize != 0) { | |||||
| ds.setExpectedDataSize(dataSize); | |||||
| } | |||||
| ds.initializeMNISTDatabase(datasetPath); | |||||
| numOfClasses = ds.getNumOfClasses(); | |||||
| if (numOfClasses != 10) { | |||||
| System.err.println("unexpected num_of_class: " + numOfClasses); | |||||
| System.exit(1); | |||||
| } | |||||
| if (ds.testData.size() == 0) { | |||||
| System.err.println("test data size is 0"); | |||||
| return -1; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| public float getLoss() { | |||||
| MSTensor tensor = searchOutputsForSize(1); | |||||
| return tensor.getFloatData()[0]; | |||||
| } | |||||
| private MSTensor searchOutputsForSize(int size) { | |||||
| Map<String, MSTensor> outputs = session.getOutputMapByTensor(); | |||||
| for (MSTensor tensor : outputs.values()) { | |||||
| if (tensor.elementsNum() == size) { | |||||
| return tensor; | |||||
| } | |||||
| } | |||||
| System.err.println("can not find output the tensor which element num is " + size); | |||||
| return null; | |||||
| } | |||||
| public int trainLoop() { | |||||
| session.train(); | |||||
| float min_loss = 1000; | |||||
| float max_acc = 0; | |||||
| for (int i = 0; i < cycles; i++) { | |||||
| fillInputData(ds.getTrainData(), false); | |||||
| session.runGraph(); | |||||
| float loss = getLoss(); | |||||
| if (min_loss > loss) { | |||||
| min_loss = loss; | |||||
| } | |||||
| if ((i + 1) % 500 == 0) { | |||||
| float acc = calculateAccuracy(10); // only test 10 batch size | |||||
| if (max_acc < acc) { | |||||
| max_acc = acc; | |||||
| } | |||||
| System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc); | |||||
| } | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| public float calculateAccuracy(long maxTests) { | |||||
| float accuracy = 0; | |||||
| Vector<DataSet.DataLabelTuple> test_set = ds.getTestData(); | |||||
| long tests = test_set.size() / batchSize; | |||||
| if (maxTests != -1 && tests < maxTests) { | |||||
| tests = maxTests; | |||||
| } | |||||
| session.eval(); | |||||
| for (long i = 0; i < tests; i++) { | |||||
| Vector<Integer> labels = fillInputData(test_set, (maxTests == -1)); | |||||
| if (labels.size() != batchSize) { | |||||
| System.err.println("unexpected labels size: " + labels.size() + " batch_size size: " + batchSize); | |||||
| System.exit(1); | |||||
| } | |||||
| session.runGraph(); | |||||
| MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses)); | |||||
| if (outputsv == null) { | |||||
| System.err.println("can not find output tensor with size: " + batchSize * numOfClasses); | |||||
| System.exit(1); | |||||
| } | |||||
| float[] scores = outputsv.getFloatData(); | |||||
| for (int b = 0; b < batchSize; b++) { | |||||
| int max_idx = 0; | |||||
| float max_score = scores[(int) (numOfClasses * b)]; | |||||
| for (int c = 0; c < numOfClasses; c++) { | |||||
| if (scores[(int) (numOfClasses * b + c)] > max_score) { | |||||
| max_score = scores[(int) (numOfClasses * b + c)]; | |||||
| max_idx = c; | |||||
| } | |||||
| } | |||||
| if (labels.get(b) == max_idx) { | |||||
| accuracy += 1.0; | |||||
| } | |||||
| } | |||||
| } | |||||
| session.train(); | |||||
| accuracy /= (batchSize * tests); | |||||
| return accuracy; | |||||
| } | |||||
| // each time fill batch_size data | |||||
| Vector<Integer> fillInputData(Vector<DataSet.DataLabelTuple> dataset, boolean serially) { | |||||
| Vector<Integer> labelsVec = new Vector<Integer>(); | |||||
| int totalSize = dataset.size(); | |||||
| List<MSTensor> inputs = session.getInputs(); | |||||
| int inputDataCnt = inputs.get(dataIndex).elementsNum(); | |||||
| float[] inputBatchData = new float[inputDataCnt]; | |||||
| int labelDataCnt = inputs.get(labelIndex).elementsNum(); | |||||
| int[] labelBatchData = new int[labelDataCnt]; | |||||
| for (int i = 0; i < batchSize; i++) { | |||||
| if (serially) { | |||||
| idx = (++idx) % totalSize; | |||||
| } else { | |||||
| idx = (int) (Math.random() * totalSize); | |||||
| } | |||||
| int label = 0; | |||||
| DataSet.DataLabelTuple dataLabelTuple = dataset.get(idx); | |||||
| label = dataLabelTuple.label; | |||||
| System.arraycopy(dataLabelTuple.data, 0, inputBatchData, (int) (i * dataLabelTuple.data.length), dataLabelTuple.data.length); | |||||
| labelBatchData[i] = label; | |||||
| labelsVec.add(label); | |||||
| } | |||||
| ByteBuffer byteBuf = ByteBuffer.allocateDirect(inputBatchData.length * Float.BYTES); | |||||
| byteBuf.order(ByteOrder.nativeOrder()); | |||||
| for (int i = 0; i < inputBatchData.length; i++) { | |||||
| byteBuf.putFloat(inputBatchData[i]); | |||||
| } | |||||
| inputs.get(dataIndex).setData(byteBuf); | |||||
| ByteBuffer labelByteBuf = ByteBuffer.allocateDirect(labelBatchData.length * 4); | |||||
| labelByteBuf.order(ByteOrder.nativeOrder()); | |||||
| for (int i = 0; i < labelBatchData.length; i++) { | |||||
| labelByteBuf.putInt(labelBatchData[i]); | |||||
| } | |||||
| inputs.get(labelIndex).setData(labelByteBuf); | |||||
| return labelsVec; | |||||
| } | |||||
| public void trainModel(String modelPath, String datasetPath) { | |||||
| System.out.println("==========Loading Model, Create Train Session============="); | |||||
| initAndFigureInputs(modelPath); | |||||
| System.out.println("==========Initing DataSet================"); | |||||
| initDB(datasetPath); | |||||
| System.out.println("==========Training Model==================="); | |||||
| trainLoop(); | |||||
| System.out.println("==========Evaluating The Trained Model============"); | |||||
| float acc = calculateAccuracy(-1); | |||||
| System.out.println("accuracy = " + acc); | |||||
| if (cycles > 0) { | |||||
| if (session.saveToFile(trainedFilePath)) { | |||||
| System.out.println("Trained model successfully saved: " + trainedFilePath); | |||||
| } else { | |||||
| System.err.println("Save model error."); | |||||
| } | |||||
| } | |||||
| session.free(); | |||||
| } | |||||
| } | |||||