Browse Source

!7761 change GetInputsByName to GetInputByTensorName

Merge pull request !7761 from yeyunpeng2020/java
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
9a6be0f48a
2 changed files with 15 additions and 23 deletions
  1. +10
    -8
      mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java
  2. +5
    -15
      mindspore/lite/java/java/app/src/main/native/runtime/lite_session.cpp

+ 10
- 8
mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java View File

@@ -66,14 +66,13 @@ public class LiteSession {
return tensors; return tensors;
} }


public List<MSTensor> getInputsByName(String nodeName) {
List<Long> ret = this.getInputsByName(this.sessionPtr, nodeName);
ArrayList<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
public MSTensor getInputsByTensorName(String tensorName) {
Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName);
if(tensor_addr == null){
return null;
} }
return tensors;
MSTensor msTensor = new MSTensor(tensor_addr);
return msTensor;
} }


public List<MSTensor> getOutputsByNodeName(String nodeName) { public List<MSTensor> getOutputsByNodeName(String nodeName) {
@@ -104,6 +103,9 @@ public class LiteSession {


public MSTensor getOutputByTensorName(String tensorName) { public MSTensor getOutputByTensorName(String tensorName) {
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName); Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
if(tensor_addr == null){
return null;
}
return new MSTensor(tensor_addr); return new MSTensor(tensor_addr);
} }


@@ -130,7 +132,7 @@ public class LiteSession {


private native List<Long> getInputs(long sessionPtr); private native List<Long> getInputs(long sessionPtr);


private native List<Long> getInputsByName(long sessionPtr, String nodeName);
private native Long getInputsByTensorName(long sessionPtr, String tensorName);


private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName); private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);




+ 5
- 15
mindspore/lite/java/java/app/src/main/native/runtime/lite_session.cpp View File

@@ -102,27 +102,17 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
return ret; return ret;
} }


extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv *env,
jobject thiz,
jlong session_ptr,
jstring tensor_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");

jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring tensor_name) {
auto *pointer = reinterpret_cast<void *>(session_ptr); auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) { if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr"); MS_LOGE("Session pointer from java is nullptr");
return ret;
return jlong(nullptr);
} }
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto input = lite_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name)); auto input = lite_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name));
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
return ret;
return jlong(input);
} }


extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz, extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz,


Loading…
Cancel
Save