Merge pull request !6217 from yeyunpeng2020/javatags/v1.0.0
| @@ -76,23 +76,6 @@ public class LiteSession { | |||
| return tensors; | |||
| } | |||
| public Map<String, List<MSTensor>> getOutputMapByNode() { | |||
| Map<String, List<Long>> ret = this.getOutputMapByNode(this.sessionPtr); | |||
| Map<String, List<MSTensor>> tensorMap = new HashMap<>(); | |||
| Set<Map.Entry<String, List<Long>>> entrySet = ret.entrySet(); | |||
| for (Map.Entry<String, List<Long>> entry : entrySet) { | |||
| String name = entry.getKey(); | |||
| List<Long> msTensorAddrs = entry.getValue(); | |||
| ArrayList<MSTensor> msTensors = new ArrayList<>(); | |||
| for (Long msTensorAddr : msTensorAddrs) { | |||
| MSTensor msTensor = new MSTensor(msTensorAddr); | |||
| msTensors.add(msTensor); | |||
| } | |||
| tensorMap.put(name, msTensors); | |||
| } | |||
| return tensorMap; | |||
| } | |||
| public List<MSTensor> getOutputsByNodeName(String nodeName) { | |||
| List<Long> ret = this.getOutputsByNodeName(this.sessionPtr, nodeName); | |||
| ArrayList<MSTensor> tensors = new ArrayList<>(); | |||
| @@ -141,8 +124,6 @@ public class LiteSession { | |||
| private native List<Long> getInputsByName(long sessionPtr, String nodeName); | |||
| private native Map<String, List<Long>> getOutputMapByNode(long sessionPtr); | |||
| private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName); | |||
| private native Map<String, Long> getOutputMapByTensor(long sessionPtr); | |||
| @@ -31,27 +31,14 @@ public class MSTensor { | |||
| this.tensorPtr = tensorPtr; | |||
| } | |||
| public boolean init(int dataType, int[] shape) { | |||
| this.tensorPtr = createMSTensor(dataType, shape, shape.length); | |||
| return this.tensorPtr != 0; | |||
| } | |||
| public int[] getShape() { | |||
| return this.getShape(this.tensorPtr); | |||
| } | |||
| public void setShape(int[] shape) { | |||
| this.setShape(this.tensorPtr, shape, shape.length); | |||
| } | |||
| public int getDataType() { | |||
| return this.getDataType(this.tensorPtr); | |||
| } | |||
| public void setDataType(int dataType) { | |||
| this.setDataType(this.tensorPtr, dataType); | |||
| } | |||
| public byte[] getByteData() { | |||
| return this.getByteData(this.tensorPtr); | |||
| } | |||
| @@ -107,16 +94,10 @@ public class MSTensor { | |||
| return ret; | |||
| } | |||
| private native long createMSTensor(int dataType, int[] shape, int shapeLen); | |||
| private native int[] getShape(long tensorPtr); | |||
| private native boolean setShape(long tensorPtr, int[] shape, int shapeLen); | |||
| private native int getDataType(long tensorPtr); | |||
| private native boolean setDataType(long tensorPtr, int dataType); | |||
| private native byte[] getByteData(long tensorPtr); | |||
| private native long[] getLongData(long tensorPtr); | |||
| @@ -126,38 +126,6 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu | |||
| return ret; | |||
| } | |||
| extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputMapByNode(JNIEnv *env, jobject thiz, | |||
| jlong session_ptr) { | |||
| jclass hash_map_clazz = env->FindClass("java/util/HashMap"); | |||
| jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V"); | |||
| jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct); | |||
| jmethodID hash_map_put = | |||
| env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); | |||
| auto *pointer = reinterpret_cast<void *>(session_ptr); | |||
| if (pointer == nullptr) { | |||
| MS_LOGE("Session pointer from java is nullptr"); | |||
| return hash_map; | |||
| } | |||
| auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | |||
| auto outputs = lite_session_ptr->GetOutputMapByNode(); | |||
| jclass long_object = env->FindClass("java/lang/Long"); | |||
| jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V"); | |||
| jclass array_list = env->FindClass("java/util/ArrayList"); | |||
| jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V"); | |||
| jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z"); | |||
| for (auto output_iter : outputs) { | |||
| auto node_name = output_iter.first; | |||
| auto ms_tensors = output_iter.second; | |||
| jobject vec = env->NewObject(array_list, array_list_construct); | |||
| for (auto ms_tensor : ms_tensors) { | |||
| jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor)); | |||
| env->CallBooleanMethod(vec, array_list_add, tensor_addr); | |||
| } | |||
| env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), vec); | |||
| } | |||
| return hash_map; | |||
| } | |||
| extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz, | |||
| jlong session_ptr, | |||
| jstring node_name) { | |||
| @@ -195,7 +163,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp | |||
| return hash_map; | |||
| } | |||
| auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer); | |||
| auto outputs = lite_session_ptr->GetOutputMapByTensor(); | |||
| auto outputs = lite_session_ptr->GetOutputs(); | |||
| jclass long_object = env->FindClass("java/lang/Long"); | |||
| jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V"); | |||
| for (auto output_iter : outputs) { | |||
| @@ -19,24 +19,6 @@ | |||
| #include "include/ms_tensor.h" | |||
| #include "ir/dtype/type_id.h" | |||
| extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createMSTensor(JNIEnv *env, jobject thiz, | |||
| jint data_type, jintArray shape, | |||
| jint shape_len) { | |||
| jboolean is_copy = false; | |||
| jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy); | |||
| std::vector<int> local_shape(shape_len); | |||
| for (size_t i = 0; i < shape_len; i++) { | |||
| local_shape[i] = local_shape_arr[i]; | |||
| } | |||
| auto *ms_tensor = mindspore::tensor::MSTensor::CreateTensor(mindspore::TypeId(data_type), local_shape); | |||
| env->ReleaseIntArrayElements(shape, local_shape_arr, JNI_ABORT); | |||
| if (ms_tensor == nullptr) { | |||
| MS_LOGE("CreateTensor failed"); | |||
| return reinterpret_cast<jlong>(nullptr); | |||
| } | |||
| return reinterpret_cast<jlong>(ms_tensor); | |||
| } | |||
| extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape(JNIEnv *env, jobject thiz, | |||
| jlong tensor_ptr) { | |||
| auto *pointer = reinterpret_cast<void *>(tensor_ptr); | |||
| @@ -57,25 +39,6 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape | |||
| return shape; | |||
| } | |||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setShape(JNIEnv *env, jobject thiz, | |||
| jlong tensor_ptr, jintArray shape, | |||
| jint shape_len) { | |||
| jboolean is_copy = false; | |||
| jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy); | |||
| auto *pointer = reinterpret_cast<void *>(tensor_ptr); | |||
| if (pointer == nullptr) { | |||
| MS_LOGE("Tensor pointer from java is nullptr"); | |||
| return static_cast<jboolean>(false); | |||
| } | |||
| auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer); | |||
| std::vector<int> local_shape(shape_len); | |||
| for (size_t i = 0; i < shape_len; i++) { | |||
| local_shape[i] = local_shape_arr[i]; | |||
| } | |||
| auto ret = ms_tensor_ptr->set_shape(local_shape); | |||
| return ret == shape_len; | |||
| } | |||
| extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(JNIEnv *env, jobject thiz, | |||
| jlong tensor_ptr) { | |||
| auto *pointer = reinterpret_cast<void *>(tensor_ptr); | |||
| @@ -87,18 +50,6 @@ extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(J | |||
| return jint(ms_tensor_ptr->data_type()); | |||
| } | |||
| extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataType(JNIEnv *env, jobject thiz, | |||
| jlong tensor_ptr, jint data_type) { | |||
| auto *pointer = reinterpret_cast<void *>(tensor_ptr); | |||
| if (pointer == nullptr) { | |||
| MS_LOGE("Tensor pointer from java is nullptr"); | |||
| return static_cast<jboolean>(false); | |||
| } | |||
| auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer); | |||
| auto ret = ms_tensor_ptr->set_data_type(mindspore::TypeId(data_type)); | |||
| return ret == data_type; | |||
| } | |||
| extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByteData(JNIEnv *env, jobject thiz, | |||
| jlong tensor_ptr) { | |||
| auto *pointer = reinterpret_cast<void *>(tensor_ptr); | |||