Browse Source

add java interface

tags/v0.7.0-beta
hangq 5 years ago
parent
commit
7c52b49c08
17 changed files with 970 additions and 5 deletions
  1. +21
    -0
      mindspore/lite/java/src/java/cn/huawei/mindspore/DataType.java
  2. +114
    -0
      mindspore/lite/java/src/java/cn/huawei/mindspore/LiteSession.java
  3. +75
    -0
      mindspore/lite/java/src/java/cn/huawei/mindspore/MSTensor.java
  4. +76
    -0
      mindspore/lite/java/src/java/cn/huawei/mindspore/Model.java
  5. +43
    -0
      mindspore/lite/java/src/java/cn/huawei/mindspore/context/Context.java
  6. +7
    -0
      mindspore/lite/java/src/java/cn/huawei/mindspore/context/CpuBindMode.java
  7. +7
    -0
      mindspore/lite/java/src/java/cn/huawei/mindspore/context/DeviceType.java
  8. +18
    -0
      mindspore/lite/java/src/native/CMakeLists.txt
  9. +37
    -0
      mindspore/lite/java/src/native/common/jni_utils.cpp
  10. +26
    -0
      mindspore/lite/java/src/native/common/jni_utils.h
  11. +19
    -0
      mindspore/lite/java/src/native/common/ms_log.cpp
  12. +36
    -0
      mindspore/lite/java/src/native/common/ms_log.h
  13. +69
    -0
      mindspore/lite/java/src/native/runtime/context.cpp
  14. +196
    -0
      mindspore/lite/java/src/native/runtime/lite_session.cpp
  15. +50
    -0
      mindspore/lite/java/src/native/runtime/model.cpp
  16. +172
    -0
      mindspore/lite/java/src/native/runtime/ms_tensor.cpp
  17. +4
    -5
      mindspore/lite/src/executor.cc

+ 21
- 0
mindspore/lite/java/src/java/cn/huawei/mindspore/DataType.java View File

@@ -0,0 +1,21 @@
package cn.huawei.mindspore;

public class DataType {
public static final int kNumberTypeBool = 30;
public static final int kNumberTypeInt = 31;
public static final int kNumberTypeInt8 = 32;
public static final int kNumberTypeInt16 = 33;
public static final int kNumberTypeInt32 = 34;
public static final int kNumberTypeInt64 = 35;
public static final int kNumberTypeUInt = 36;
public static final int kNumberTypeUInt8 = 37;
public static final int kNumberTypeUInt16 = 38;
public static final int kNumberTypeUint32 = 39;
public static final int kNumberTypeUInt64 = 40;
public static final int kNumberTypeFloat = 41;
public static final int kNumberTypeFloat16 = 42;
public static final int kNumberTypeFloat32 = 43;
public static final int kNumberTypeFloat64 = 44;

public static native int elementSize(int elementType);
}

+ 114
- 0
mindspore/lite/java/src/java/cn/huawei/mindspore/LiteSession.java View File

@@ -0,0 +1,114 @@
package cn.huawei.mindspore;

import cn.huawei.mindspore.context.Context;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class LiteSession {
static {
System.loadLibrary("mindspore-lite-jni");
}

private long sessionPtr;

LiteSession() {
this.sessionPtr = 0;
}

public boolean init(Context context) {
this.sessionPtr = createSession(context.getContextPtr());
return this.sessionPtr != 0;
}

public long getSessionPtr() {
return sessionPtr;
}

public void bindThread(boolean if_bind) {
this.bindThread(this.sessionPtr, if_bind);
}

public boolean compileGraph(Model model) {
return this.compileGraph(this.sessionPtr, model.getModelPtr());
}

public boolean runGraph() {
return this.runGraph(this.sessionPtr);
}

public List<MSTensor> getInputs() {
List<Long> ret = this.getInputs(this.sessionPtr);
ArrayList<MSTensor> tensors = new ArrayList<MSTensor>();
for (Long ms_tensor_addr : ret) {
MSTensor msTensor = new MSTensor(ms_tensor_addr);
tensors.add(msTensor);
}
return tensors;
}

public List<MSTensor> getInputsByName(String nodeName) {
List<Long> ret = this.getInputsByName(this.sessionPtr, nodeName);
ArrayList<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
}
return tensors;
}

public Map<String, List<MSTensor>> getOutputs() {
Map<String, List<Long>> ret = this.getOutputs(this.sessionPtr);
Map<String, List<MSTensor>> tensorMap = new HashMap<>();
Set<Map.Entry<String, List<Long>>> entrySet = ret.entrySet();
for (Map.Entry<String, List<Long>> entry : entrySet) {
String name = entry.getKey();
List<Long> msTensorAddrs = entry.getValue();
ArrayList<MSTensor> msTensors = new ArrayList<>();
for (Long msTensorAddr : msTensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr);
msTensors.add(msTensor);
}
tensorMap.put(name, msTensors);
}
return tensorMap;
}

public List<MSTensor> getOutputsByName(String nodeName) {
List<Long> ret = this.getOutputsByName(this.sessionPtr, nodeName);
ArrayList<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
}
return tensors;
}

public void free() {
this.free(this.sessionPtr);
this.sessionPtr = 0;
}

private native long createSession(long contextPtr);

private native boolean compileGraph(long sessionPtr, long modelPtr);

private native void bindThread(long sessionPtr, boolean if_bind);

private native boolean runGraph(long sessionPtr);

private native List<Long> getInputs(long sessionPtr);

private native List<Long> getInputsByName(long sessionPtr, String nodeName);

private native Map<String, List<Long>> getOutputs(long sessionPtr);

private native List<Long> getOutputsByName(long sessionPtr, String nodeName);

private native void free(long sessionPtr);
}

+ 75
- 0
mindspore/lite/java/src/java/cn/huawei/mindspore/MSTensor.java View File

@@ -0,0 +1,75 @@
package cn.huawei.mindspore;

public class MSTensor {
private long tensorPtr;

public MSTensor() {
this.tensorPtr = 0;
}

public MSTensor(long tensorPtr) {
this.tensorPtr = tensorPtr;
}

public boolean init (int dataType, int[] shape) {
this.tensorPtr = createMSTensor(dataType, shape, shape.length);
return this.tensorPtr != 0;
}

public int[] getShape() {
return this.getShape(this.tensorPtr);
}

public void setShape(int[] shape) {
this.setShape(this.tensorPtr, shape, shape.length);
}

public int getDataType() {
return this.getDataType(this.tensorPtr);
}

public void setDataType(int dataType) {
this.setDataType(this.tensorPtr, dataType);
}

public byte[] getData() {
return this.getData(this.tensorPtr);
}

public void setData(byte[] data) {
this.setData(this.tensorPtr, data, data.length);
}

public long size() {
return this.size(this.tensorPtr);
}

public int elementsNum() {
return this.elementsNum(this.tensorPtr);
}

public void free() {
this.free(this.tensorPtr);
this.tensorPtr = 0;
}

private native long createMSTensor(int dataType, int[] shape, int shapeLen);

private native int[] getShape(long tensorPtr);

private native boolean setShape(long tensorPtr, int[] shape, int shapeLen);

private native int getDataType(long tensorPtr);

private native boolean setDataType(long tensorPtr, int dataType);

private native byte[] getData(long tensorPtr);

private native boolean setData(long tensorPtr, byte[] data, long dataLen);

private native long size(long tensorPtr);

private native int elementsNum(long tensorPtr);

private native void free(long tensorPtr);
}

+ 76
- 0
mindspore/lite/java/src/java/cn/huawei/mindspore/Model.java View File

@@ -0,0 +1,76 @@
package cn.huawei.mindspore;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.util.Log;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

public class Model {
static {
System.loadLibrary("mindspore-lite-jni");
}

private long modelPtr;

public Model() {
this.modelPtr = 0;
}

public long getModelPtr() {
return modelPtr;
}

public void setModelPtr(long modelPtr) {
this.modelPtr = modelPtr;
}

public boolean loadModel(Context context, String modelName) {
FileInputStream fis = null;
AssetFileDescriptor fileDescriptor = null;
boolean ret = false;
try {
fileDescriptor = context.getAssets().openFd(modelName);
fis = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = fis.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLen = fileDescriptor.getDeclaredLength();
MappedByteBuffer buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLen);
this.modelPtr = loadModel(buffer);
ret = this.modelPtr != 0;
} catch (IOException e) {
this.modelPtr = 0;
Log.e("MS_LITE", "Load model failed: " + e.getMessage());
ret = false;
} finally {
if (null != fis) {
try {
fis.close();
} catch (IOException e) {
Log.e("MS_LITE", "Close file failed: " + e.getMessage());
}
}
if (null != fileDescriptor) {
try {
fis.close();
} catch (IOException e) {
Log.e("MS_LITE", "Close fileDescriptor failed: " + e.getMessage());
}
}
}
return ret;
}

public void free() {
this.free(this.modelPtr);
this.modelPtr = 0;
}

private native long loadModel(MappedByteBuffer buffer);

private native void free(long modelPtr);
}

+ 43
- 0
mindspore/lite/java/src/java/cn/huawei/mindspore/context/Context.java View File

@@ -0,0 +1,43 @@
package cn.huawei.mindspore.context;

public class Context {
private long contextPtr;

public Context() {
this.contextPtr = 0;
}

public long getContextPtr() {
return contextPtr;
}

public void setContextPtr(long contextPtr) {
this.contextPtr = contextPtr;
}

public boolean init(int deviceType, int threadNum, int cpuBindMode) {
this.contextPtr = createContext(deviceType, threadNum, cpuBindMode);
return this.contextPtr != 0;
}

public boolean init(int deviceType, int threadNum) {
return init(deviceType, threadNum, CpuBindMode.MID_CPU);
}

public boolean init(int deviceType) {
return init(deviceType, 2);
}

public boolean init() {
return init(DeviceType.DT_CPU);
}

public void free() {
this.free(this.contextPtr);
this.contextPtr = 0;
}

private native long createContext(int deviceType, int threadNum, int cpuBindMode);

private native void free(long contextPtr);
}

+ 7
- 0
mindspore/lite/java/src/java/cn/huawei/mindspore/context/CpuBindMode.java View File

@@ -0,0 +1,7 @@
package cn.huawei.mindspore.context;

public class CpuBindMode {
public static final int MID_CPU = -1;
public static final int HIGHER_CPU = 1;
public static final int NO_BIND = 0;
}

+ 7
- 0
mindspore/lite/java/src/java/cn/huawei/mindspore/context/DeviceType.java View File

@@ -0,0 +1,7 @@
package cn.huawei.mindspore.context;

public class DeviceType {
public static final int DT_CPU = 0;
public static final int DT_GPU = 1;
public static final int DT_NPU = 2;
}

+ 18
- 0
mindspore/lite/java/src/native/CMakeLists.txt View File

@@ -0,0 +1,18 @@
cmake_minimum_required(VERSION 3.4.1)

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/libs/${ANDROID_ABI})

add_library(mindspore-lite-jni SHARED
${CMAKE_CURRENT_SOURCE_DIR}/src/common/ms_log.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/common/jni_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/context.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/ms_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/lite_session.cpp
)

find_library(log-lib log)

target_link_libraries(mindspore-lite-jni mindspore-lite ${log-lib} )

+ 37
- 0
mindspore/lite/java/src/native/common/jni_utils.cpp View File

@@ -0,0 +1,37 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/jni_utils.h"
#include <cstring>

char *JstringToChar(JNIEnv *env, jstring jstr) {
char *rtn = NULL;
jclass clsstring = env->FindClass("java/lang/String");
jstring strencode = env->NewStringUTF("GB2312");
jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
jbyteArray barr = (jbyteArray)env->CallObjectMethod(jstr, mid, strencode);
jsize alen = env->GetArrayLength(barr);
jbyte *ba = env->GetByteArrayElements(barr, JNI_FALSE);
if (alen > 0) {
rtn = new char[alen + 1];
memcpy(rtn, ba, alen);
rtn[alen] = 0;
}
env->ReleaseByteArrayElements(barr, ba, 0);
return rtn;
}

+ 26
- 0
mindspore/lite/java/src/native/common/jni_utils.h View File

@@ -0,0 +1,26 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_JAVA_SRC_COMMON_JNI_UTILS_H
#define MINDSPORE_LITE_JAVA_SRC_COMMON_JNI_UTILS_H

#include <jni.h>

char *JstringToChar(JNIEnv *env, jstring jstr);

#endif // MINDSPORE_LITE_JAVA_SRC_COMMON_JNI_UTILS_H

+ 19
- 0
mindspore/lite/java/src/native/common/ms_log.cpp View File

@@ -0,0 +1,19 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/ms_log.h"

+ 36
- 0
mindspore/lite/java/src/native/common/ms_log.h View File

@@ -0,0 +1,36 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_JAVA_SRC_COMMON_MS_LOG_H
#define MINDSPORE_LITE_JAVA_SRC_COMMON_MS_LOG_H

#include <android/log.h>
#include <unistd.h>

#define TAG "MS_LITE"

#define MS_LOGD(fmt, args...) \
{ __android_log_print(ANDROID_LOG_DEBUG, TAG, "|%d|%s[%d]|: " fmt, getpid(), __func__, __LINE__, ##args); }

#define MS_LOGE(fmt, args...) \
{ __android_log_print(ANDROID_LOG_ERROR, TAG, "|%d|%s[%d]|: " fmt, getpid(), __func__, __LINE__, ##args); }

#define MS_LOGI(fmt, args...) \
{ __android_log_print(ANDROID_LOG_INFO, TAG, "|%d|%s[%d]|: " fmt, getpid(), __func__, __LINE__, ##args); }

#endif // MINDSPORE_LITE_JAVA_SRC_COMMON_MS_LOG_H

+ 69
- 0
mindspore/lite/java/src/native/runtime/context.cpp View File

@@ -0,0 +1,69 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <jni.h>
#include "common/ms_log.h"
#include "include/context.h"

extern "C" JNIEXPORT jlong JNICALL Java_cn_huawei_mindspore_context_Context_createContext(JNIEnv *env, jobject thiz,
jint device_type,
jint thread_num,
jint cpu_bind_mode) {
auto *context = new mindspore::lite::Context();
switch (device_type) {
case 0:
context->device_ctx_.type = mindspore::lite::DT_CPU;
break;
case 1:
context->device_ctx_.type = mindspore::lite::DT_GPU;
break;
case 2:
context->device_ctx_.type = mindspore::lite::DT_NPU;
break;
default:
MS_LOGE("Invalid device_type : %d", device_type);
return (jlong)context;
}
switch (cpu_bind_mode) {
case -1:
context->cpu_bind_mode_ = mindspore::lite::MID_CPU;
break;
case 0:
context->cpu_bind_mode_ = mindspore::lite::NO_BIND;
break;
case 1:
context->cpu_bind_mode_ = mindspore::lite::HIGHER_CPU;
break;
default:
MS_LOGE("Invalid cpu_bind_mode : %d", cpu_bind_mode);
return (jlong)context;
}
context->thread_num_ = thread_num;
return (jlong)context;
}

extern "C" JNIEXPORT void JNICALL Java_cn_huawei_mindspore_context_Context_free(JNIEnv *env, jobject thiz,
jlong context_ptr) {
auto *pointer = static_cast<void *>(context_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return;
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
delete (lite_context_ptr);
}

+ 196
- 0
mindspore/lite/java/src/native/runtime/lite_session.cpp View File

@@ -0,0 +1,196 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <jni.h>
#include "common/ms_log.h"
#include "common/jni_utils.h"
#include "include/lite_session.h"
#include "include/errorcode.h"

extern "C" JNIEXPORT jlong JNICALL Java_cn_huawei_mindspore_LiteSession_createSession(JNIEnv *env, jobject thiz,
jlong context_ptr) {
auto *pointer = static_cast<void *>(context_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_context_ptr = static_cast<mindspore::lite::Context *>(pointer);
auto session = mindspore::session::LiteSession::CreateSession(lite_context_ptr);
if (session == nullptr) {
MS_LOGE("CreateSession failed");
return jlong(nullptr);
}
return jlong(session);
}

extern "C" JNIEXPORT jboolean JNICALL Java_cn_huawei_mindspore_LiteSession_compileGraph(JNIEnv *env, jobject thiz,
jlong session_ptr,
jlong model_ptr) {
auto *session_pointer = static_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(session_pointer);
auto *model_pointer = static_cast<void *>(model_ptr);
if (model_pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_model_ptr = static_cast<mindspore::lite::Model *>(model_pointer);

auto ret = lite_session_ptr->CompileGraph(lite_model_ptr);
return (jboolean)(ret == mindspore::lite::RET_OK);
}

extern "C" JNIEXPORT void JNICALL Java_cn_huawei_mindspore_LiteSession_bindThread(JNIEnv *env, jobject thiz,
jlong session_ptr, jboolean if_bind) {
auto *pointer = static_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
lite_session_ptr->BindThread(if_bind);
}

extern "C" JNIEXPORT jboolean JNICALL Java_cn_huawei_mindspore_LiteSession_runGraph(JNIEnv *env, jobject thiz,
jlong session_ptr) {
auto *pointer = static_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return (jboolean) false;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto ret = lite_session_ptr->RunGraph();
return (jboolean)(ret == mindspore::lite::RET_OK);
}

extern "C" JNIEXPORT jobject JNICALL Java_cn_huawei_mindspore_LiteSession_getInputs(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");

jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
auto *pointer = static_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetInputs();
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
}
return ret;
}

extern "C" JNIEXPORT jobject JNICALL Java_cn_huawei_mindspore_LiteSession_getInputsByName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring node_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");

jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
auto *pointer = static_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetInputsByName(JstringToChar(env, node_name));
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
}
return ret;
}

extern "C" JNIEXPORT jobject JNICALL Java_cn_huawei_mindspore_LiteSession_getOutputs(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass hash_map_clazz = env->FindClass("java/util/HashMap");
jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
jmethodID hash_map_put =
env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
auto *pointer = static_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return hash_map;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto outputs = lite_session_ptr->GetOutputs();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
for (auto output_iter : outputs) {
auto node_name = output_iter.first;
auto ms_tensors = output_iter.second;
jobject vec = env->NewObject(array_list, array_list_construct);
for (auto ms_tensor : ms_tensors) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor));
env->CallBooleanMethod(vec, array_list_add, tensor_addr);
}
env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), vec);
}
return hash_map;
}

extern "C" JNIEXPORT jobject JNICALL Java_cn_huawei_mindspore_LiteSession_getOutputsByName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring node_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");

jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
auto *pointer = static_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetOutputsByName(JstringToChar(env, node_name));
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
}
return ret;
}

extern "C" JNIEXPORT void JNICALL Java_cn_huawei_mindspore_LiteSession_free(JNIEnv *env, jobject thiz,
jlong session_ptr) {
auto *pointer = static_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
delete (lite_session_ptr);
}

+ 50
- 0
mindspore/lite/java/src/native/runtime/model.cpp View File

@@ -0,0 +1,50 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <jni.h>
#include "common/ms_log.h"
#include "include/model.h"

extern "C" JNIEXPORT jlong JNICALL Java_cn_huawei_mindspore_Model_loadModel(JNIEnv *env, jobject thiz, jobject buffer) {
MS_LOGD("Start getting buffer from java");
if (buffer == nullptr) {
MS_LOGE("Buffer from java is nullptr");
return reinterpret_cast<jlong>(nullptr);
}
jlong buffer_len = env->GetDirectBufferCapacity(buffer);
auto *model_buffer = static_cast<char *>(env->GetDirectBufferAddress(buffer));

MS_LOGD("Start Loading model");
auto model = mindspore::lite::Model::Import(model_buffer, buffer_len);
// env->DeleteLocalRef(*(jobject *)model_buffer);
if (model == nullptr) {
MS_LOGE("Import model failed");
return reinterpret_cast<jlong>(nullptr);
}
return reinterpret_cast<jlong>(model);
}

extern "C" JNIEXPORT void JNICALL Java_cn_huawei_mindspore_Model_free(JNIEnv *env, jobject thiz, jlong model_ptr) {
auto *pointer = static_cast<void *>(model_ptr);
if (pointer == nullptr) {
MS_LOGE("Model pointer from java is nullptr");
return;
}
auto *lite_model_ptr = static_cast<mindspore::lite::Model *>(pointer);
delete (lite_model_ptr);
}

+ 172
- 0
mindspore/lite/java/src/native/runtime/ms_tensor.cpp View File

@@ -0,0 +1,172 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <jni.h>
#include "common/ms_log.h"
#include "include/ms_tensor.h"
#include "ir/dtype/type_id.h"

extern "C" JNIEXPORT jlong JNICALL Java_cn_huawei_mindspore_MSTensor_createMSTensor(JNIEnv *env, jobject thiz,
jint data_type, jintArray shape,
jint shape_len) {
jboolean is_copy = false;
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
std::vector<int> local_shape(shape_len);
for (size_t i = 0; i < shape_len; i++) {
local_shape[i] = local_shape_arr[i];
}
auto *ms_tensor = mindspore::tensor::MSTensor::CreateTensor(mindspore::TypeId(data_type), local_shape);
env->ReleaseIntArrayElements(shape, local_shape_arr, JNI_ABORT);
if (ms_tensor == nullptr) {
MS_LOGE("CreateTensor failed");
return reinterpret_cast<jlong>(nullptr);
}
return reinterpret_cast<jlong>(ms_tensor);
}

extern "C" JNIEXPORT jintArray JNICALL Java_cn_huawei_mindspore_MSTensor_getShape(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = static_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewIntArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto local_shape = ms_tensor_ptr->shape();
auto shape_size = local_shape.size();
jintArray shape = env->NewIntArray(shape_size);
auto *tmp = new jint[shape_size];
for (size_t i = 0; i < shape_size; i++) {
tmp[i] = local_shape.at(i);
}
delete[](tmp);
env->SetIntArrayRegion(shape, 0, shape_size, tmp);
return shape;
}

extern "C" JNIEXPORT jboolean JNICALL Java_cn_huawei_mindspore_MSTensor_setShape(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jintArray shape,
jint shape_len) {
jboolean is_copy = false;
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
auto *pointer = static_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
std::vector<int> local_shape(shape_len);
for (size_t i = 0; i < shape_len; i++) {
local_shape[i] = local_shape_arr[i];
}
auto ret = ms_tensor_ptr->set_shape(local_shape);
return ret == shape_len;
}

extern "C" JNIEXPORT jint JNICALL Java_cn_huawei_mindspore_MSTensor_getDataType(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = static_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
return jint(ms_tensor_ptr->data_type());
}

extern "C" JNIEXPORT jboolean JNICALL Java_cn_huawei_mindspore_MSTensor_setDataType(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jint data_type) {
auto *pointer = static_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto ret = ms_tensor_ptr->set_data_type(mindspore::TypeId(data_type));
return ret == data_type;
}

extern "C" JNIEXPORT jbyteArray JNICALL Java_cn_huawei_mindspore_MSTensor_getData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = static_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewByteArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto *local_data = static_cast<jbyte *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewByteArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewByteArray(local_data_size);
env->SetByteArrayRegion(ret, 0, local_data_size, local_data);
return ret;
}

extern "C" JNIEXPORT jboolean JNICALL Java_cn_huawei_mindspore_MSTensor_setData(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jbyteArray data,
jlong data_len) {
auto *pointer = static_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
if (data_len != ms_tensor_ptr->Size()) {
MS_LOGE("data_len(%ld) not equal to Size of ms_tensor(%zu)", data_len, ms_tensor_ptr->Size());
return static_cast<jboolean>(false);
}
jboolean is_copy = false;
auto *data_arr = env->GetByteArrayElements(data, &is_copy);
auto *local_data = ms_tensor_ptr->MutableData();
memcpy(local_data, data_arr, data_len);
return static_cast<jboolean>(true);
}

extern "C" JNIEXPORT jlong JNICALL Java_cn_huawei_mindspore_MSTensor_size(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
auto *pointer = static_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return 0;
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
return ms_tensor_ptr->Size();
}

extern "C" JNIEXPORT jint JNICALL Java_cn_huawei_mindspore_MSTensor_elementsNum(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = static_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return 0;
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
return ms_tensor_ptr->ElementsNum();
}

extern "C" JNIEXPORT void JNICALL Java_cn_huawei_mindspore_MSTensor_free(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
auto *pointer = static_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return;
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
delete (ms_tensor_ptr);
}

+ 4
- 5
mindspore/lite/src/executor.cc View File

@@ -37,12 +37,10 @@ int Executor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Ten
kernel::LiteKernelUtil::InitTensorRefCount(kernels);
for (auto *kernel : kernels) {
MS_ASSERT(nullptr != kernel);
session::CallBackParam callbackParam;
callbackParam.name_callback_param = kernel->Name();
callbackParam.type_callback_param = kernel->type_str();

if (before != nullptr) {
if (!before(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()), callbackParam)) {
if (!before(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()),
{kernel->Name(), kernel->type_str()})) {
MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->Name();
}
}
@@ -53,7 +51,8 @@ int Executor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Ten
}

if (after != nullptr) {
if (!after(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()), callbackParam)) {
if (!after(PackToMSTensors(kernel->GetInputs()), PackToMSTensors(kernel->GetOutputs()),
{kernel->Name(), kernel->type_str()})) {
MS_LOG(ERROR) << "run kernel after_callback failed, name: " << kernel->Name();
}
}


Loading…
Cancel
Save