|
|
|
@@ -55,19 +55,19 @@ public class TrainSession { |
|
|
|
public List<MSTensor> getInputs() { |
|
|
|
List<Long> ret = this.getInputs(this.sessionPtr); |
|
|
|
ArrayList<MSTensor> tensors = new ArrayList<MSTensor>(); |
|
|
|
for (Long ms_tensor_addr : ret) { |
|
|
|
MSTensor msTensor = new MSTensor(ms_tensor_addr); |
|
|
|
for (Long msTensorAddr : ret) { |
|
|
|
MSTensor msTensor = new MSTensor(msTensorAddr); |
|
|
|
tensors.add(msTensor); |
|
|
|
} |
|
|
|
return tensors; |
|
|
|
} |
|
|
|
|
|
|
|
public MSTensor getInputsByTensorName(String tensorName) { |
|
|
|
Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName); |
|
|
|
if(tensor_addr == null){ |
|
|
|
Long tensorAddr = this.getInputsByTensorName(this.sessionPtr, tensorName); |
|
|
|
if(tensorAddr == null) { |
|
|
|
return null; |
|
|
|
} |
|
|
|
MSTensor msTensor = new MSTensor(tensor_addr); |
|
|
|
MSTensor msTensor = new MSTensor(tensorAddr); |
|
|
|
return msTensor; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -98,11 +98,11 @@ public class TrainSession { |
|
|
|
} |
|
|
|
|
|
|
|
public MSTensor getOutputByTensorName(String tensorName) { |
|
|
|
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName); |
|
|
|
if(tensor_addr == null){ |
|
|
|
Long tensorAddr = getOutputByTensorName(this.sessionPtr, tensorName); |
|
|
|
if(tensorAddr == null) { |
|
|
|
return null; |
|
|
|
} |
|
|
|
return new MSTensor(tensor_addr); |
|
|
|
return new MSTensor(tensorAddr); |
|
|
|
} |
|
|
|
|
|
|
|
public void free() { |
|
|
|
@@ -111,11 +111,11 @@ public class TrainSession { |
|
|
|
} |
|
|
|
|
|
|
|
public boolean resize(List<MSTensor> inputs, int[][] dims) { |
|
|
|
long[] inputs_array = new long[inputs.size()]; |
|
|
|
long[] inputsArray = new long[inputs.size()]; |
|
|
|
for (int i = 0; i < inputs.size(); i++) { |
|
|
|
inputs_array[i] = inputs.get(i).getMSTensorPtr(); |
|
|
|
inputsArray[i] = inputs.get(i).getMSTensorPtr(); |
|
|
|
} |
|
|
|
return this.resize(this.sessionPtr, inputs_array, dims); |
|
|
|
return this.resize(this.sessionPtr, inputsArray, dims); |
|
|
|
} |
|
|
|
|
|
|
|
public boolean saveToFile(String modelFilename) { |
|
|
|
|