diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs
index e6e60e32..402ffca1 100644
--- a/src/TensorFlowNET.Core/APIs/tf.graph.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs
@@ -24,7 +24,7 @@ namespace Tensorflow
return ops.get_default_graph();
}
- public static Graph Graph() => new Graph();
-
+ public static Graph Graph()
+ => new Graph();
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs
index 958106ee..7fcfdbd7 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs
@@ -55,11 +55,11 @@ namespace Tensorflow
return Status;
}
- public static Graph ImportFromPB(string file_path)
+ public static Graph ImportFromPB(string file_path, string name = null)
{
var graph = tf.Graph().as_default();
var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path));
- importer.import_graph_def(graph_def);
+ importer.import_graph_def(graph_def, name: name);
return graph;
}
}
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 4b7d641c..1ec4f6f3 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -5,7 +5,7 @@
TensorFlow.NET
Tensorflow
1.14.0
- 0.10.2
+ 0.10.3
Haiping Chen, Meinrad Recheis
SciSharp STACK
true
@@ -16,9 +16,8 @@
https://avatars3.githubusercontent.com/u/44989469?s=200&v=4
TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#
Google's TensorFlow full binding in .NET Standard.
-Docs: https://tensorflownet.readthedocs.io
-Medium: https://medium.com/scisharp
- 0.10.2.0
+Docs: https://tensorflownet.readthedocs.io
+ 0.10.3.0
Changes since v0.9.0:
1. Added full connected Convolution Neural Network example.
@@ -29,11 +28,12 @@ Medium: https://medium.com/scisharp
6. Add StridedSliceGrad.
7. Add BatchMatMulGrad.
8. Upgrade NumSharp.
-9. Fix strided_slice_grad type convention error.
+9. Fix strided_slice_grad type convention error.
+10. Add AbsGrad.
7.2
- 0.10.2.0
+ 0.10.3.0
LICENSE
- false
+ true
true
Open.snk
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
index 628a7b5c..78f23155 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
@@ -80,6 +80,16 @@ namespace TensorFlowNET.Examples.ImageProcess
{
PrepareData();
+ #region For debug purpose
+
+ // predict images
+ Predict(null);
+
+ // load saved pb and test new images.
+ Test(null);
+
+ #endregion
+
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
with(tf.Session(graph), sess =>
@@ -708,14 +718,38 @@ namespace TensorFlowNET.Examples.ImageProcess
File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));
}
- public void Predict(Session sess)
+ public void Predict(Session sess_)
{
- throw new NotImplementedException();
+ if (!File.Exists(output_graph))
+ return;
+
+ var graph = Graph.ImportFromPB(output_graph, "");
+
+ Tensor input_layer = graph.OperationByName("input/BottleneckInputPlaceholder");
+ Tensor output_layer = graph.OperationByName("final_result");
+
+ with(tf.Session(graph), sess =>
+ {
+ // load images into NDArray in a matrix[image_num, features]
+ var nd = np.arange(2048f).reshape(1, 2048); // replace this line
+ var result = sess.run(output_layer, new FeedItem(input_layer, nd));
+ });
}
- public void Test(Session sess)
+ public void Test(Session sess_)
{
- throw new NotImplementedException();
+ if (!File.Exists(output_graph))
+ return;
+
+ var graph = Graph.ImportFromPB(output_graph);
+ var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();
+
+ with(tf.Session(graph), sess =>
+ {
+ (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists,
+ jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
+ bottleneck_tensor);
+ });
}
}
}