Browse Source

!6217 [MS][LITE]change java api

Merge pull request !6217 from yeyunpeng2020/java
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
9d723b58c0
4 changed files with 1 additions and 120 deletions
  1. +0
    -19
      mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java
  2. +0
    -19
      mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java
  3. +1
    -33
      mindspore/lite/java/java/app/src/main/native/runtime/lite_session.cpp
  4. +0
    -49
      mindspore/lite/java/java/app/src/main/native/runtime/ms_tensor.cpp

+ 0
- 19
mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/LiteSession.java View File

@@ -76,23 +76,6 @@ public class LiteSession {
return tensors;
}

public Map<String, List<MSTensor>> getOutputMapByNode() {
Map<String, List<Long>> ret = this.getOutputMapByNode(this.sessionPtr);
Map<String, List<MSTensor>> tensorMap = new HashMap<>();
Set<Map.Entry<String, List<Long>>> entrySet = ret.entrySet();
for (Map.Entry<String, List<Long>> entry : entrySet) {
String name = entry.getKey();
List<Long> msTensorAddrs = entry.getValue();
ArrayList<MSTensor> msTensors = new ArrayList<>();
for (Long msTensorAddr : msTensorAddrs) {
MSTensor msTensor = new MSTensor(msTensorAddr);
msTensors.add(msTensor);
}
tensorMap.put(name, msTensors);
}
return tensorMap;
}

public List<MSTensor> getOutputsByNodeName(String nodeName) {
List<Long> ret = this.getOutputsByNodeName(this.sessionPtr, nodeName);
ArrayList<MSTensor> tensors = new ArrayList<>();
@@ -141,8 +124,6 @@ public class LiteSession {

private native List<Long> getInputsByName(long sessionPtr, String nodeName);

private native Map<String, List<Long>> getOutputMapByNode(long sessionPtr);

private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);

private native Map<String, Long> getOutputMapByTensor(long sessionPtr);


+ 0
- 19
mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java View File

@@ -31,27 +31,14 @@ public class MSTensor {
this.tensorPtr = tensorPtr;
}

public boolean init(int dataType, int[] shape) {
this.tensorPtr = createMSTensor(dataType, shape, shape.length);
return this.tensorPtr != 0;
}

public int[] getShape() {
return this.getShape(this.tensorPtr);
}

public void setShape(int[] shape) {
this.setShape(this.tensorPtr, shape, shape.length);
}

public int getDataType() {
return this.getDataType(this.tensorPtr);
}

public void setDataType(int dataType) {
this.setDataType(this.tensorPtr, dataType);
}

public byte[] getByteData() {
return this.getByteData(this.tensorPtr);
}
@@ -107,16 +94,10 @@ public class MSTensor {
return ret;
}

private native long createMSTensor(int dataType, int[] shape, int shapeLen);

private native int[] getShape(long tensorPtr);

private native boolean setShape(long tensorPtr, int[] shape, int shapeLen);

private native int getDataType(long tensorPtr);

private native boolean setDataType(long tensorPtr, int dataType);

private native byte[] getByteData(long tensorPtr);

private native long[] getLongData(long tensorPtr);


+ 1
- 33
mindspore/lite/java/java/app/src/main/native/runtime/lite_session.cpp View File

@@ -126,38 +126,6 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
return ret;
}

extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputMapByNode(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass hash_map_clazz = env->FindClass("java/util/HashMap");
jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
jmethodID hash_map_put =
env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return hash_map;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto outputs = lite_session_ptr->GetOutputMapByNode();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
for (auto output_iter : outputs) {
auto node_name = output_iter.first;
auto ms_tensors = output_iter.second;
jobject vec = env->NewObject(array_list, array_list_construct);
for (auto ms_tensor : ms_tensors) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor));
env->CallBooleanMethod(vec, array_list_add, tensor_addr);
}
env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), vec);
}
return hash_map;
}

extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring node_name) {
@@ -195,7 +163,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return hash_map;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto outputs = lite_session_ptr->GetOutputMapByTensor();
auto outputs = lite_session_ptr->GetOutputs();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
for (auto output_iter : outputs) {


+ 0
- 49
mindspore/lite/java/java/app/src/main/native/runtime/ms_tensor.cpp View File

@@ -19,24 +19,6 @@
#include "include/ms_tensor.h"
#include "ir/dtype/type_id.h"

extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createMSTensor(JNIEnv *env, jobject thiz,
jint data_type, jintArray shape,
jint shape_len) {
jboolean is_copy = false;
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
std::vector<int> local_shape(shape_len);
for (size_t i = 0; i < shape_len; i++) {
local_shape[i] = local_shape_arr[i];
}
auto *ms_tensor = mindspore::tensor::MSTensor::CreateTensor(mindspore::TypeId(data_type), local_shape);
env->ReleaseIntArrayElements(shape, local_shape_arr, JNI_ABORT);
if (ms_tensor == nullptr) {
MS_LOGE("CreateTensor failed");
return reinterpret_cast<jlong>(nullptr);
}
return reinterpret_cast<jlong>(ms_tensor);
}

extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
@@ -57,25 +39,6 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape
return shape;
}

extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setShape(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jintArray shape,
jint shape_len) {
jboolean is_copy = false;
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
std::vector<int> local_shape(shape_len);
for (size_t i = 0; i < shape_len; i++) {
local_shape[i] = local_shape_arr[i];
}
auto ret = ms_tensor_ptr->set_shape(local_shape);
return ret == shape_len;
}

extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
@@ -87,18 +50,6 @@ extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(J
return jint(ms_tensor_ptr->data_type());
}

extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataType(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jint data_type) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto ret = ms_tensor_ptr->set_data_type(mindspore::TypeId(data_type));
return ret == data_type;
}

extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByteData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);


Loading…
Cancel
Save