| @@ -671,6 +671,12 @@ build_lite_java_arm64() { | |||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/arm64-v8a/ | |||
| else | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/arm64-v8a/ | |||
| @@ -697,6 +703,12 @@ build_lite_java_arm32() { | |||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | |||
| else | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/armeabi-v7a/ | |||
| @@ -706,10 +718,15 @@ build_lite_java_arm32() { | |||
| build_lite_java_x86() { | |||
| # build mindspore-lite x86 | |||
| local inference_or_train=inference | |||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||
| inference_or_train=train | |||
| fi | |||
| if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then | |||
| local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD} | |||
| local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD} | |||
| else | |||
| local JTARBALL=mindspore-lite-${VERSION_STR}-inference-linux-x64 | |||
| local JTARBALL=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64 | |||
| fi | |||
| if [[ "X$INC_BUILD" == "Xoff" ]] || [[ ! -f "${BASEPATH}/mindspore/lite/build/java/${JTARBALL}.tar.gz" ]]; then | |||
| build_lite "x86_64" "off" "" | |||
| @@ -721,8 +738,20 @@ build_lite_java_x86() { | |||
| [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/linux_x86/libs/ | |||
| mkdir -p ${JAVA_PATH}/java/linux_x86/libs/ | |||
| mkdir -p ${JAVA_PATH}/native/libs/linux_x86/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ | |||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/java/linux_x86/libs/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/lib/libminddata-lite.so ${JAVA_PATH}/native/libs/linux_x86/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/java/linux_x86/libs/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/train/minddata/third_party/libjpeg-turbo/lib/*.so* ${JAVA_PATH}/native/libs/linux_x86/ | |||
| else | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/java/linux_x86/libs/ | |||
| cp ${BASEPATH}/mindspore/lite/build/java/${JTARBALL}/inference/lib/libmindspore-lite.so ${JAVA_PATH}/native/libs/linux_x86/ | |||
| fi | |||
| [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL} | |||
| } | |||
| build_jni_arm64() { | |||
| @@ -776,7 +805,7 @@ build_jni_x86_64() { | |||
| mkdir -pv java/jni | |||
| cd java/jni | |||
| cmake -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ | |||
| -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${JAVA_PATH}/native/" | |||
| -DENABLE_VERBOSE=${ENABLE_VERBOSE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} "${JAVA_PATH}/native/" | |||
| make -j$THREAD_NUM | |||
| if [[ $? -ne 0 ]]; then | |||
| echo "---------------- mindspore lite: build jni x86_64 failed----------------" | |||
| @@ -825,11 +854,16 @@ build_java() { | |||
| cd ${JAVA_PATH}/java/app/build | |||
| zip -r mindspore-lite-maven-${VERSION_STR}.zip mindspore | |||
| local inference_or_train=inference | |||
| if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then | |||
| inference_or_train=train | |||
| fi | |||
| # build linux x86 jar | |||
| if [[ "$X86_64_SIMD" == "sse" || "$X86_64_SIMD" == "avx" ]]; then | |||
| local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-${X86_64_SIMD}-jar | |||
| local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-${X86_64_SIMD}-jar | |||
| else | |||
| local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-inference-linux-x64-jar | |||
| local LINUX_X86_PACKAGE_NAME=mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64-jar | |||
| fi | |||
| check_java_home | |||
| build_lite_java_x86 | |||
| @@ -843,15 +877,17 @@ build_java() { | |||
| gradle releaseJar | |||
| # install and package | |||
| mkdir -p ${JAVA_PATH}/java/linux_x86/build/lib | |||
| cp ${JAVA_PATH}/java/linux_x86/libs/*.so ${JAVA_PATH}/java/linux_x86/build/lib/jar | |||
| cp ${JAVA_PATH}/java/linux_x86/libs/*.so* ${JAVA_PATH}/java/linux_x86/build/lib/jar | |||
| cd ${JAVA_PATH}/java/linux_x86/build/ | |||
| cp -r ${JAVA_PATH}/java/linux_x86/build/lib ${JAVA_PATH}/java/linux_x86/build/${LINUX_X86_PACKAGE_NAME} | |||
| tar czvf ${LINUX_X86_PACKAGE_NAME}.tar.gz ${LINUX_X86_PACKAGE_NAME} | |||
| # copy output | |||
| cp ${JAVA_PATH}/java/app/build/mindspore-lite-maven-${VERSION_STR}.zip ${BASEPATH}/output | |||
| cp ${LINUX_X86_PACKAGE_NAME}.tar.gz ${BASEPATH}/output | |||
| cd ${BASEPATH}/output | |||
| [ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-inference-linux-x64 | |||
| [ -n "${VERSION_STR}" ] && rm -rf ${BASEPATH}/mindspore/lite/build/java/mindspore-lite-${VERSION_STR}-${inference_or_train}-linux-x64 | |||
| exit 0 | |||
| } | |||
| @@ -0,0 +1,55 @@ | |||
| <?xml version="1.0" encoding="UTF-8"?> | |||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | |||
| xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | |||
| xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | |||
| <modelVersion>4.0.0</modelVersion> | |||
| <groupId>com.mindspore.lite.demo</groupId> | |||
| <artifactId>train_lenet_java</artifactId> | |||
| <version>1.0</version> | |||
| <properties> | |||
| <maven.compiler.source>8</maven.compiler.source> | |||
| <maven.compiler.target>8</maven.compiler.target> | |||
| </properties> | |||
| <dependencies> | |||
| <dependency> | |||
| <groupId>com.mindspore.lite</groupId> | |||
| <artifactId>mindspore-lite-java</artifactId> | |||
| <version>1.0</version> | |||
| <scope>system</scope> | |||
| <systemPath>${project.basedir}/lib/mindspore-lite-java.jar</systemPath> | |||
| </dependency> | |||
| </dependencies> | |||
| <build> | |||
| <finalName>${project.name}</finalName> | |||
| <plugins> | |||
| <plugin> | |||
| <groupId>org.apache.maven.plugins</groupId> | |||
| <artifactId>maven-assembly-plugin</artifactId> | |||
| <configuration> | |||
| <archive> | |||
| <manifest> | |||
| <mainClass>com.mindspore.lite.train_lenet.Main</mainClass> | |||
| </manifest> | |||
| </archive> | |||
| <descriptorRefs> | |||
| <descriptorRef>jar-with-dependencies</descriptorRef> | |||
| </descriptorRefs> | |||
| </configuration> | |||
| <executions> | |||
| <execution> | |||
| <id>make-assemble</id> | |||
| <phase>package</phase> | |||
| <goals> | |||
| <goal>single</goal> | |||
| </goals> | |||
| </execution> | |||
| </executions> | |||
| </plugin> | |||
| </plugins> | |||
| </build> | |||
| </project> | |||
| @@ -0,0 +1,131 @@ | |||
| package com.mindspore.lite.train_lenet; | |||
| import java.io.BufferedInputStream; | |||
| import java.io.FileInputStream; | |||
| import java.io.IOException; | |||
| import java.util.Vector; | |||
| public class DataSet { | |||
| private long numOfClasses = 0; | |||
| private long expectedDataSize = 0; | |||
| public class DataLabelTuple { | |||
| public float[] data; | |||
| public int label; | |||
| } | |||
| Vector<DataLabelTuple> trainData; | |||
| Vector<DataLabelTuple> testData; | |||
| public void initializeMNISTDatabase(String dpath) { | |||
| numOfClasses = 10; | |||
| trainData = new Vector<DataLabelTuple>(); | |||
| testData = new Vector<DataLabelTuple>(); | |||
| readMNISTFile(dpath + "/train/train-images-idx3-ubyte", dpath+"/train/train-labels-idx1-ubyte", trainData); | |||
| readMNISTFile(dpath + "/test/t10k-images-idx3-ubyte", dpath+"/test/t10k-labels-idx1-ubyte", testData); | |||
| System.out.println("train data cnt: " + trainData.size()); | |||
| System.out.println("test data cnt: " + testData.size()); | |||
| } | |||
| private String bytesToHex(byte[] bytes) { | |||
| StringBuffer sb = new StringBuffer(); | |||
| for (int i = 0; i < bytes.length; i++) { | |||
| String hex = Integer.toHexString(bytes[i] & 0xFF); | |||
| if (hex.length() < 2) { | |||
| sb.append(0); | |||
| } | |||
| sb.append(hex); | |||
| } | |||
| return sb.toString(); | |||
| } | |||
| private void readFile(BufferedInputStream inputStream, byte[] bytes, int len) throws IOException { | |||
| int result = inputStream.read(bytes, 0, len); | |||
| if (result != len) { | |||
| System.err.println("expected read " + len + " bytes, but " + result + " read"); | |||
| System.exit(1); | |||
| } | |||
| } | |||
| public void readMNISTFile(String inputFileName, String labelFileName, Vector<DataLabelTuple> dataset) { | |||
| try { | |||
| BufferedInputStream ibin = new BufferedInputStream(new FileInputStream(inputFileName)); | |||
| BufferedInputStream lbin = new BufferedInputStream(new FileInputStream(labelFileName)); | |||
| byte[] bytes = new byte[4]; | |||
| readFile(ibin, bytes, 4); | |||
| if (!"00000803".equals(bytesToHex(bytes))) { // 2051 | |||
| System.err.println("The dataset is not valid: " + bytesToHex(bytes)); | |||
| return; | |||
| } | |||
| readFile(ibin, bytes, 4); | |||
| int inumber = Integer.parseInt(bytesToHex(bytes), 16); | |||
| readFile(lbin, bytes, 4); | |||
| if (!"00000801".equals(bytesToHex(bytes))) { // 2049 | |||
| System.err.println("The dataset label is not valid: " + bytesToHex(bytes)); | |||
| return; | |||
| } | |||
| readFile(lbin, bytes, 4); | |||
| int lnumber = Integer.parseInt(bytesToHex(bytes), 16); | |||
| if (inumber != lnumber) { | |||
| System.err.println("input data cnt: " + inumber + " not equal label cnt: " + lnumber); | |||
| return; | |||
| } | |||
| // read all labels | |||
| byte[] labels = new byte[lnumber]; | |||
| readFile(lbin, labels, lnumber); | |||
| // row, column | |||
| readFile(ibin, bytes, 4); | |||
| int n_rows = Integer.parseInt(bytesToHex(bytes), 16); | |||
| readFile(ibin, bytes, 4); | |||
| int n_cols = Integer.parseInt(bytesToHex(bytes), 16); | |||
| if (n_rows != 28 || n_cols != 28) { | |||
| System.err.println("invalid n_rows: " + n_rows + " n_cols: " + n_cols); | |||
| return; | |||
| } | |||
| // read images | |||
| int image_size = n_rows * n_cols; | |||
| byte[] image_data = new byte[image_size]; | |||
| for (int i = 0; i < lnumber; i++) { | |||
| float [] hwc_bin_image = new float[32 * 32]; | |||
| readFile(ibin, image_data, image_size); | |||
| for (int r = 0; r < 32; r++) { | |||
| for (int c = 0; c < 32; c++) { | |||
| int index = r * 32 + c; | |||
| if (r < 2 || r > 29 || c < 2 || c > 29) { | |||
| hwc_bin_image[index] = 0; | |||
| } else { | |||
| int data = image_data[(r-2)*28 + (c-2)] & 0xff; | |||
| hwc_bin_image[index] = (float)data / 255.0f; | |||
| } | |||
| } | |||
| } | |||
| DataLabelTuple data_label_tupel = new DataLabelTuple(); | |||
| data_label_tupel.data = hwc_bin_image; | |||
| data_label_tupel.label = labels[i] & 0xff; | |||
| dataset.add(data_label_tupel); | |||
| } | |||
| } catch (IOException e) { | |||
| System.err.println("Read Dateset exception"); | |||
| } | |||
| } | |||
| public void setExpectedDataSize(long data_size) { | |||
| expectedDataSize = data_size; | |||
| } | |||
| public long getNumOfClasses() { | |||
| return numOfClasses; | |||
| } | |||
| public Vector<DataLabelTuple> getTestData() { | |||
| return testData; | |||
| } | |||
| public Vector<DataLabelTuple> getTrainData() { | |||
| return trainData; | |||
| } | |||
| } | |||
| @@ -0,0 +1,20 @@ | |||
| package com.mindspore.lite.train_lenet; | |||
| import com.mindspore.lite.Version; | |||
| public class Main { | |||
| public static void main(String[] args) { | |||
| System.out.println(Version.version()); | |||
| if (args.length < 2) { | |||
| System.err.println("model path and dataset path must be provided."); | |||
| return; | |||
| } | |||
| String modelPath = args[0]; | |||
| String datasetPath = args[1]; | |||
| NetRunner net_runner = new NetRunner(); | |||
| net_runner.trainModel(modelPath, datasetPath); | |||
| } | |||
| } | |||
| @@ -0,0 +1,220 @@ | |||
| package com.mindspore.lite.train_lenet; | |||
| import com.mindspore.lite.MSTensor; | |||
| import com.mindspore.lite.TrainSession; | |||
| import com.mindspore.lite.config.MSConfig; | |||
| import java.nio.ByteBuffer; | |||
| import java.nio.ByteOrder; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.Vector; | |||
| public class NetRunner { | |||
| private int dataIndex = 0; | |||
| private int labelIndex = 1; | |||
| private TrainSession session; | |||
| private long batchSize; | |||
| private long dataSize; // one input data size, in byte | |||
| private DataSet ds = new DataSet(); | |||
| private long numOfClasses; | |||
| private long cycles = 3500; | |||
| private int idx = 1; | |||
| private String trainedFilePath = "trained.ms"; | |||
| public void initAndFigureInputs(String modelPath) { | |||
| MSConfig msConfig = new MSConfig(); | |||
| // arg 0: DeviceType:DT_CPU -> 0 | |||
| // arg 1: ThreadNum -> 2 | |||
| // arg 2: cpuBindMode:NO_BIND -> 0 | |||
| // arg 3: enable_fp16 -> false | |||
| msConfig.init(0, 2, 0, false); | |||
| session = new TrainSession(); | |||
| session.init(modelPath, msConfig); | |||
| session.setLearningRate(0.01f); | |||
| List<MSTensor> inputs = session.getInputs(); | |||
| if (inputs.size() <= 1) { | |||
| System.err.println("model input size: " + inputs.size()); | |||
| return; | |||
| } | |||
| dataIndex = 0; | |||
| labelIndex = 1; | |||
| batchSize = inputs.get(dataIndex).getShape()[0]; | |||
| dataSize = inputs.get(dataIndex).size() / batchSize; | |||
| System.out.println("batch_size: " + batchSize); | |||
| int index = modelPath.lastIndexOf(".ms"); | |||
| if (index == -1) { | |||
| System.out.println("The model " + modelPath + " should be named *.ms"); | |||
| return; | |||
| } | |||
| trainedFilePath = modelPath.substring(0, index) + "_trained.ms"; | |||
| } | |||
| public int initDB(String datasetPath) { | |||
| if (dataSize != 0) { | |||
| ds.setExpectedDataSize(dataSize); | |||
| } | |||
| ds.initializeMNISTDatabase(datasetPath); | |||
| numOfClasses = ds.getNumOfClasses(); | |||
| if (numOfClasses != 10) { | |||
| System.err.println("unexpected num_of_class: " + numOfClasses); | |||
| System.exit(1); | |||
| } | |||
| if (ds.testData.size() == 0) { | |||
| System.err.println("test data size is 0"); | |||
| return -1; | |||
| } | |||
| return 0; | |||
| } | |||
| public float getLoss() { | |||
| MSTensor tensor = searchOutputsForSize(1); | |||
| return tensor.getFloatData()[0]; | |||
| } | |||
| private MSTensor searchOutputsForSize(int size) { | |||
| Map<String, MSTensor> outputs = session.getOutputMapByTensor(); | |||
| for (MSTensor tensor : outputs.values()) { | |||
| if (tensor.elementsNum() == size) { | |||
| return tensor; | |||
| } | |||
| } | |||
| System.err.println("can not find output the tensor which element num is " + size); | |||
| return null; | |||
| } | |||
| public int trainLoop() { | |||
| session.train(); | |||
| float min_loss = 1000; | |||
| float max_acc = 0; | |||
| for (int i = 0; i < cycles; i++) { | |||
| fillInputData(ds.getTrainData(), false); | |||
| session.runGraph(); | |||
| float loss = getLoss(); | |||
| if (min_loss > loss) { | |||
| min_loss = loss; | |||
| } | |||
| if ((i + 1) % 500 == 0) { | |||
| float acc = calculateAccuracy(10); // only test 10 batch size | |||
| if (max_acc < acc) { | |||
| max_acc = acc; | |||
| } | |||
| System.out.println("step_" + (i + 1) + ": \tLoss is " + loss + " [min=" + min_loss + "]" + " max_accc=" + max_acc); | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| public float calculateAccuracy(long maxTests) { | |||
| float accuracy = 0; | |||
| Vector<DataSet.DataLabelTuple> test_set = ds.getTestData(); | |||
| long tests = test_set.size() / batchSize; | |||
| if (maxTests != -1 && tests < maxTests) { | |||
| tests = maxTests; | |||
| } | |||
| session.eval(); | |||
| for (long i = 0; i < tests; i++) { | |||
| Vector<Integer> labels = fillInputData(test_set, (maxTests == -1)); | |||
| if (labels.size() != batchSize) { | |||
| System.err.println("unexpected labels size: " + labels.size() + " batch_size size: " + batchSize); | |||
| System.exit(1); | |||
| } | |||
| session.runGraph(); | |||
| MSTensor outputsv = searchOutputsForSize((int) (batchSize * numOfClasses)); | |||
| if (outputsv == null) { | |||
| System.err.println("can not find output tensor with size: " + batchSize * numOfClasses); | |||
| System.exit(1); | |||
| } | |||
| float[] scores = outputsv.getFloatData(); | |||
| for (int b = 0; b < batchSize; b++) { | |||
| int max_idx = 0; | |||
| float max_score = scores[(int) (numOfClasses * b)]; | |||
| for (int c = 0; c < numOfClasses; c++) { | |||
| if (scores[(int) (numOfClasses * b + c)] > max_score) { | |||
| max_score = scores[(int) (numOfClasses * b + c)]; | |||
| max_idx = c; | |||
| } | |||
| } | |||
| if (labels.get(b) == max_idx) { | |||
| accuracy += 1.0; | |||
| } | |||
| } | |||
| } | |||
| session.train(); | |||
| accuracy /= (batchSize * tests); | |||
| return accuracy; | |||
| } | |||
| // each time fill batch_size data | |||
| Vector<Integer> fillInputData(Vector<DataSet.DataLabelTuple> dataset, boolean serially) { | |||
| Vector<Integer> labelsVec = new Vector<Integer>(); | |||
| int totalSize = dataset.size(); | |||
| List<MSTensor> inputs = session.getInputs(); | |||
| int inputDataCnt = inputs.get(dataIndex).elementsNum(); | |||
| float[] inputBatchData = new float[inputDataCnt]; | |||
| int labelDataCnt = inputs.get(labelIndex).elementsNum(); | |||
| int[] labelBatchData = new int[labelDataCnt]; | |||
| for (int i = 0; i < batchSize; i++) { | |||
| if (serially) { | |||
| idx = (++idx) % totalSize; | |||
| } else { | |||
| idx = (int) (Math.random() * totalSize); | |||
| } | |||
| int label = 0; | |||
| DataSet.DataLabelTuple dataLabelTuple = dataset.get(idx); | |||
| label = dataLabelTuple.label; | |||
| System.arraycopy(dataLabelTuple.data, 0, inputBatchData, (int) (i * dataLabelTuple.data.length), dataLabelTuple.data.length); | |||
| labelBatchData[i] = label; | |||
| labelsVec.add(label); | |||
| } | |||
| ByteBuffer byteBuf = ByteBuffer.allocateDirect(inputBatchData.length * Float.BYTES); | |||
| byteBuf.order(ByteOrder.nativeOrder()); | |||
| for (int i = 0; i < inputBatchData.length; i++) { | |||
| byteBuf.putFloat(inputBatchData[i]); | |||
| } | |||
| inputs.get(dataIndex).setData(byteBuf); | |||
| ByteBuffer labelByteBuf = ByteBuffer.allocateDirect(labelBatchData.length * 4); | |||
| labelByteBuf.order(ByteOrder.nativeOrder()); | |||
| for (int i = 0; i < labelBatchData.length; i++) { | |||
| labelByteBuf.putInt(labelBatchData[i]); | |||
| } | |||
| inputs.get(labelIndex).setData(labelByteBuf); | |||
| return labelsVec; | |||
| } | |||
| public void trainModel(String modelPath, String datasetPath) { | |||
| System.out.println("==========Loading Model, Create Train Session============="); | |||
| initAndFigureInputs(modelPath); | |||
| System.out.println("==========Initing DataSet================"); | |||
| initDB(datasetPath); | |||
| System.out.println("==========Training Model==================="); | |||
| trainLoop(); | |||
| System.out.println("==========Evaluating The Trained Model============"); | |||
| float acc = calculateAccuracy(-1); | |||
| System.out.println("accuracy = " + acc); | |||
| if (cycles > 0) { | |||
| if (session.saveToFile(trainedFilePath)) { | |||
| System.out.println("Trained model successfully saved: " + trainedFilePath); | |||
| } else { | |||
| System.err.println("Save model error."); | |||
| } | |||
| } | |||
| session.free(); | |||
| } | |||
| } | |||