| @@ -59,19 +59,19 @@ public class LiteSession { | |||||
| public List<MSTensor> getInputs() { | public List<MSTensor> getInputs() { | ||||
| List<Long> ret = this.getInputs(this.sessionPtr); | List<Long> ret = this.getInputs(this.sessionPtr); | ||||
| ArrayList<MSTensor> tensors = new ArrayList<MSTensor>(); | ArrayList<MSTensor> tensors = new ArrayList<MSTensor>(); | ||||
| for (Long ms_tensor_addr : ret) { | |||||
| MSTensor msTensor = new MSTensor(ms_tensor_addr); | |||||
| for (Long msTensorAddr : ret) { | |||||
| MSTensor msTensor = new MSTensor(msTensorAddr); | |||||
| tensors.add(msTensor); | tensors.add(msTensor); | ||||
| } | } | ||||
| return tensors; | return tensors; | ||||
| } | } | ||||
| public MSTensor getInputsByTensorName(String tensorName) { | public MSTensor getInputsByTensorName(String tensorName) { | ||||
| Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName); | |||||
| if(tensor_addr == null){ | |||||
| Long tensorAddr = this.getInputsByTensorName(this.sessionPtr, tensorName); | |||||
| if (tensorAddr == null) { | |||||
| return null; | return null; | ||||
| } | } | ||||
| MSTensor msTensor = new MSTensor(tensor_addr); | |||||
| MSTensor msTensor = new MSTensor(tensorAddr); | |||||
| return msTensor; | return msTensor; | ||||
| } | } | ||||
| @@ -102,11 +102,11 @@ public class LiteSession { | |||||
| } | } | ||||
| public MSTensor getOutputByTensorName(String tensorName) { | public MSTensor getOutputByTensorName(String tensorName) { | ||||
| Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName); | |||||
| if(tensor_addr == null){ | |||||
| Long tensorAddr = getOutputByTensorName(this.sessionPtr, tensorName); | |||||
| if (tensorAddr == null) { | |||||
| return null; | return null; | ||||
| } | } | ||||
| return new MSTensor(tensor_addr); | |||||
| return new MSTensor(tensorAddr); | |||||
| } | } | ||||
| public void free() { | public void free() { | ||||
| @@ -115,11 +115,11 @@ public class LiteSession { | |||||
| } | } | ||||
| public boolean resize(List<MSTensor> inputs, int[][] dims) { | public boolean resize(List<MSTensor> inputs, int[][] dims) { | ||||
| long[] inputs_array = new long[inputs.size()]; | |||||
| long[] inputsArray = new long[inputs.size()]; | |||||
| for (int i = 0; i < inputs.size(); i++) { | for (int i = 0; i < inputs.size(); i++) { | ||||
| inputs_array[i] = inputs.get(i).getMSTensorPtr(); | |||||
| inputsArray[i] = inputs.get(i).getMSTensorPtr(); | |||||
| } | } | ||||
| return this.resize(this.sessionPtr, inputs_array, dims); | |||||
| return this.resize(this.sessionPtr, inputsArray, dims); | |||||
| } | } | ||||
| private native long createSession(long msConfigPtr); | private native long createSession(long msConfigPtr); | ||||
| @@ -29,7 +29,7 @@ public class MSConfig { | |||||
| } | } | ||||
| public boolean init(int deviceType, int threadNum, int cpuBindMode) { | public boolean init(int deviceType, int threadNum, int cpuBindMode) { | ||||
| this.msConfigPtr = createMSConfig(deviceType, threadNum, cpuBindMode ,false); | |||||
| this.msConfigPtr = createMSConfig(deviceType, threadNum, cpuBindMode, false); | |||||
| return this.msConfigPtr != 0; | return this.msConfigPtr != 0; | ||||
| } | } | ||||
| @@ -50,4 +50,5 @@ CastNPUKernel::~CastNPUKernel() { | |||||
| } | } | ||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Cast, NPUKernelCreator<CastNPUKernel>) | REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Cast, NPUKernelCreator<CastNPUKernel>) | ||||
| REG_KERNEL(kNPU, kNumberTypeInt32, PrimitiveType_Cast, NPUKernelCreator<CastNPUKernel>) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -53,5 +53,6 @@ GatherNPUKernel::~GatherNPUKernel() { | |||||
| op_ = nullptr; | op_ = nullptr; | ||||
| } | } | ||||
| } | } | ||||
| // NPU input index 0 datatype not support: 3(int32). | |||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Gather, NPUKernelCreator<GatherNPUKernel>) | REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Gather, NPUKernelCreator<GatherNPUKernel>) | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Shape; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int ShapeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | int ShapeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | ||||
| OpParameter *opParameter) { | OpParameter *opParameter) { | ||||
| return RET_OK; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| int ShapeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | int ShapeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, | ||||
| @@ -27,9 +27,9 @@ int StridedSliceNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, | |||||
| // Only onnx StridedSlice has 5 inputs, of which the 4th input is axes and the 5th input is strides. | // Only onnx StridedSlice has 5 inputs, of which the 4th input is axes and the 5th input is strides. | ||||
| if (inputs.size() == 5) { | if (inputs.size() == 5) { | ||||
| vector<int> axes; | vector<int> axes; | ||||
| size_t size = inputs[4]->shape()[0]; | |||||
| size_t size = inputs[3]->shape()[0]; | |||||
| axes.resize(size); | axes.resize(size); | ||||
| memcpy(axes.data(), inputs[4]->data_c(), sizeof(int) * size); | |||||
| memcpy(axes.data(), inputs[3]->data_c(), sizeof(int) * size); | |||||
| for (int i = 0; i < axes.size(); ++i) { | for (int i = 0; i < axes.size(); ++i) { | ||||
| if (i != axes[i]) { | if (i != axes[i]) { | ||||
| MS_LOG(ERROR) << "Does not support setting axis, so the axis must be continuous."; | MS_LOG(ERROR) << "Does not support setting axis, so the axis must be continuous."; | ||||
| @@ -77,4 +77,5 @@ StridedSliceNPUKernel::~StridedSliceNPUKernel() { | |||||
| } | } | ||||
| REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_StridedSlice, NPUKernelCreator<StridedSliceNPUKernel>) | REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_StridedSlice, NPUKernelCreator<StridedSliceNPUKernel>) | ||||
| REG_KERNEL(kNPU, kNumberTypeInt32, PrimitiveType_StridedSlice, NPUKernelCreator<StridedSliceNPUKernel>) | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||