diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs index e6e60e32..402ffca1 100644 --- a/src/TensorFlowNET.Core/APIs/tf.graph.cs +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -24,7 +24,7 @@ namespace Tensorflow return ops.get_default_graph(); } - public static Graph Graph() => new Graph(); - + public static Graph Graph() + => new Graph(); } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs index 958106ee..7fcfdbd7 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -55,11 +55,11 @@ namespace Tensorflow return Status; } - public static Graph ImportFromPB(string file_path) + public static Graph ImportFromPB(string file_path, string name = null) { var graph = tf.Graph().as_default(); var graph_def = GraphDef.Parser.ParseFrom(File.ReadAllBytes(file_path)); - importer.import_graph_def(graph_def); + importer.import_graph_def(graph_def, name: name); return graph; } } diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 4b7d641c..1ec4f6f3 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.10.2 + 0.10.3 Haiping Chen, Meinrad Recheis SciSharp STACK true @@ -16,9 +16,8 @@ https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. -Docs: https://tensorflownet.readthedocs.io -Medium: https://medium.com/scisharp - 0.10.2.0 +Docs: https://tensorflownet.readthedocs.io + 0.10.3.0 Changes since v0.9.0: 1. Added full connected Convolution Neural Network example. @@ -29,11 +28,12 @@ Medium: https://medium.com/scisharp 6. Add StridedSliceGrad. 7. Add BatchMatMulGrad. 8. Upgrade NumSharp. -9. Fix strided_slice_grad type convention error. +9. Fix strided_slice_grad type convention error. +10. Add AbsGrad. 7.2 - 0.10.2.0 + 0.10.3.0 LICENSE - false + true true Open.snk diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs index 628a7b5c..78f23155 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs @@ -80,6 +80,16 @@ namespace TensorFlowNET.Examples.ImageProcess { PrepareData(); + #region For debug purpose + + // predict images + Predict(null); + + // load saved pb and test new images. + Test(null); + + #endregion + var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); with(tf.Session(graph), sess => @@ -708,14 +718,38 @@ namespace TensorFlowNET.Examples.ImageProcess File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys)); } - public void Predict(Session sess) + public void Predict(Session sess_) { - throw new NotImplementedException(); + if (!File.Exists(output_graph)) + return; + + var graph = Graph.ImportFromPB(output_graph, ""); + + Tensor input_layer = graph.OperationByName("input/BottleneckInputPlaceholder"); + Tensor output_layer = graph.OperationByName("final_result"); + + with(tf.Session(graph), sess => + { + // load images into NDArray in a matrix[image_num, features] + var nd = np.arange(2048f).reshape(1, 2048); // replace this line + var result = sess.run(output_layer, new FeedItem(input_layer, nd)); + }); } - public void Test(Session sess) + public void Test(Session sess_) { - throw new NotImplementedException(); + if (!File.Exists(output_graph)) + return; + + var graph = Graph.ImportFromPB(output_graph); + var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding(); + + with(tf.Session(graph), sess => + { + (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists, + jpeg_data_tensor, decoded_image_tensor, resized_image_tensor, + bottleneck_tensor); + }); } } }