diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java index 3e4274e31e..0dcf8d2049 100644 --- a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java +++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/MSTensor.java @@ -1,12 +1,12 @@ /** * Copyright 2020 Huawei Technologies Co., Ltd - * + *
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + *
* http://www.apache.org/licenses/LICENSE-2.0 - * + *
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -18,6 +18,8 @@ package com.mindspore.lite;
import android.util.Log;
+import java.nio.ByteBuffer;
+
public class MSTensor {
private long tensorPtr;
@@ -29,7 +31,7 @@ public class MSTensor {
this.tensorPtr = tensorPtr;
}
- public boolean init (int dataType, int[] shape) {
+ public boolean init(int dataType, int[] shape) {
this.tensorPtr = createMSTensor(dataType, shape, shape.length);
return this.tensorPtr != 0;
}
@@ -50,18 +52,30 @@ public class MSTensor {
this.setDataType(this.tensorPtr, dataType);
}
- public byte[] getData() {
- return this.getData(this.tensorPtr);
+ public byte[] getBtyeData() {
+ return this.getByteData(this.tensorPtr);
}
public float[] getFloatData() {
- return decodeBytes(this.getData(this.tensorPtr));
+ return this.getFloatData(this.tensorPtr);
+ }
+
+ public int[] getIntData() {
+ return this.getIntData(this.tensorPtr);
+ }
+
+ public long[] getLongData() {
+ return this.getLongData(this.tensorPtr);
}
public void setData(byte[] data) {
this.setData(this.tensorPtr, data, data.length);
}
+ public void setData(ByteBuffer data) {
+ this.setByteBufferData(this.tensorPtr, data);
+ }
+
public long size() {
return this.size(this.tensorPtr);
}
@@ -82,13 +96,13 @@ public class MSTensor {
}
int size = bytes.length / 4;
float[] ret = new float[size];
- for (int i = 0; i < size; i=i+4) {
+ for (int i = 0; i < size; i = i + 4) {
int accNum = 0;
accNum = accNum | (bytes[i] & 0xff) << 0;
- accNum = accNum | (bytes[i+1] & 0xff) << 8;
- accNum = accNum | (bytes[i+2] & 0xff) << 16;
- accNum = accNum | (bytes[i+3] & 0xff) << 24;
- ret[i/4] = Float.intBitsToFloat(accNum);
+ accNum = accNum | (bytes[i + 1] & 0xff) << 8;
+ accNum = accNum | (bytes[i + 2] & 0xff) << 16;
+ accNum = accNum | (bytes[i + 3] & 0xff) << 24;
+ ret[i / 4] = Float.intBitsToFloat(accNum);
}
return ret;
}
@@ -103,10 +117,18 @@ public class MSTensor {
private native boolean setDataType(long tensorPtr, int dataType);
- private native byte[] getData(long tensorPtr);
+ private native byte[] getByteData(long tensorPtr);
+
+ private native long[] getLongData(long tensorPtr);
+
+ private native int[] getIntData(long tensorPtr);
+
+ private native float[] getFloatData(long tensorPtr);
private native boolean setData(long tensorPtr, byte[] data, long dataLen);
+ private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer);
+
private native long size(long tensorPtr);
private native int elementsNum(long tensorPtr);
diff --git a/mindspore/lite/java/native/runtime/ms_tensor.cpp b/mindspore/lite/java/native/runtime/ms_tensor.cpp
index 8f3a607f85..d71aca9f41 100644
--- a/mindspore/lite/java/native/runtime/ms_tensor.cpp
+++ b/mindspore/lite/java/native/runtime/ms_tensor.cpp
@@ -99,8 +99,8 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataTy
return ret == data_type;
}
-extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData(JNIEnv *env, jobject thiz,
- jlong tensor_ptr) {
+extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByteData(JNIEnv *env, jobject thiz,
+ jlong tensor_ptr) {
auto *pointer = reinterpret_cast