diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java index 64db530153..411a2bd1c0 100644 --- a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java +++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java @@ -66,14 +66,13 @@ public class LiteSession { return tensors; } - public List getInputsByName(String nodeName) { - List ret = this.getInputsByName(this.sessionPtr, nodeName); - ArrayList tensors = new ArrayList<>(); - for (Long msTensorAddr : ret) { - MSTensor msTensor = new MSTensor(msTensorAddr); - tensors.add(msTensor); + public MSTensor getInputsByTensorName(String tensorName) { + Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName); + if(tensor_addr == null){ + return null; } - return tensors; + MSTensor msTensor = new MSTensor(tensor_addr); + return msTensor; } public List getOutputsByNodeName(String nodeName) { @@ -104,6 +103,9 @@ public class LiteSession { public MSTensor getOutputByTensorName(String tensorName) { Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName); + if(tensor_addr == null){ + return null; + } return new MSTensor(tensor_addr); } @@ -130,7 +132,7 @@ public class LiteSession { private native List getInputs(long sessionPtr); - private native List getInputsByName(long sessionPtr, String nodeName); + private native Long getInputsByTensorName(long sessionPtr, String tensorName); private native List getOutputsByNodeName(long sessionPtr, String nodeName); diff --git a/mindspore/lite/java/java/app/src/main/native/runtime/lite_session.cpp b/mindspore/lite/java/java/app/src/main/native/runtime/lite_session.cpp index 544b67650a..faea92c6f7 100644 --- a/mindspore/lite/java/java/app/src/main/native/runtime/lite_session.cpp +++ b/mindspore/lite/java/java/app/src/main/native/runtime/lite_session.cpp @@ -102,27 +102,17 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu return ret; } -extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv *env, - jobject thiz, - jlong session_ptr, - jstring tensor_name) { - jclass array_list = env->FindClass("java/util/ArrayList"); - jmethodID array_list_construct = env->GetMethodID(array_list, "", "()V"); - jobject ret = env->NewObject(array_list, array_list_construct); - jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z"); - - jclass long_object = env->FindClass("java/lang/Long"); - jmethodID long_object_construct = env->GetMethodID(long_object, "", "(J)V"); +extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getInputsByTensorName(JNIEnv *env, jobject thiz, + jlong session_ptr, + jstring tensor_name) { auto *pointer = reinterpret_cast(session_ptr); if (pointer == nullptr) { MS_LOGE("Session pointer from java is nullptr"); - return ret; + return jlong(nullptr); } auto *lite_session_ptr = static_cast(pointer); auto input = lite_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name)); - jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input)); - env->CallBooleanMethod(ret, array_list_add, tensor_addr); - return ret; + return jlong(input); } extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz,