|
|
@@ -64,7 +64,7 @@ public class TrainSession { |
|
|
|
|
|
|
|
|
public MSTensor getInputsByTensorName(String tensorName) { |
|
|
public MSTensor getInputsByTensorName(String tensorName) { |
|
|
Long tensorAddr = this.getInputsByTensorName(this.sessionPtr, tensorName); |
|
|
Long tensorAddr = this.getInputsByTensorName(this.sessionPtr, tensorName); |
|
|
if(tensorAddr == null) { |
|
|
|
|
|
|
|
|
if (tensorAddr == null) { |
|
|
return null; |
|
|
return null; |
|
|
} |
|
|
} |
|
|
MSTensor msTensor = new MSTensor(tensorAddr); |
|
|
MSTensor msTensor = new MSTensor(tensorAddr); |
|
|
@@ -99,7 +99,7 @@ public class TrainSession { |
|
|
|
|
|
|
|
|
public MSTensor getOutputByTensorName(String tensorName) { |
|
|
public MSTensor getOutputByTensorName(String tensorName) { |
|
|
Long tensorAddr = getOutputByTensorName(this.sessionPtr, tensorName); |
|
|
Long tensorAddr = getOutputByTensorName(this.sessionPtr, tensorName); |
|
|
if(tensorAddr == null) { |
|
|
|
|
|
|
|
|
if (tensorAddr == null) { |
|
|
return null; |
|
|
return null; |
|
|
} |
|
|
} |
|
|
return new MSTensor(tensorAddr); |
|
|
return new MSTensor(tensorAddr); |
|
|
|