From 170f3a774e8b16fc3fa214d2baea120ba0399bdc Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 9 Feb 2020 12:42:20 -0600 Subject: [PATCH] tf.trainer.load_graph, tf.trainer.freeze_graph --- src/TensorFlowNET.Core/APIs/tf.train.cs | 6 ++++ .../TensorFlow.Binding.csproj | 9 ++--- .../Training/Saving/saver.py.cs | 33 +++++++++++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index 862212ef..8d9957ac 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -53,6 +53,12 @@ namespace Tensorflow public string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); + public Graph load_graph(string freeze_graph_pb) + => saver.load_graph(freeze_graph_pb); + + public string freeze_graph(string checkpoint_dir, string output_pb_name) + => saver.freeze_graph(checkpoint_dir, output_pb_name); + public Saver import_meta_graph(string meta_graph_or_file, bool clear_devices = false, string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file, diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index 031d2b10..cf3540c2 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.1 - 0.15.0 + 0.14.1 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true @@ -18,11 +18,12 @@ Google's TensorFlow full binding in .NET Standard. Building, training and infering deep learning models. https://tensorflownet.readthedocs.io - 0.15.0.0 + 0.14.1.0 Changes since v0.14.0: -1: Add TransformGraphWithStringInputs. +1: Add TransformGraphWithStringInputs. +2: tf.trainer.load_graph, tf.trainer.freeze_graph 7.3 - 0.15.0.0 + 0.14.1.0 LICENSE true true diff --git a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs index 2b75947b..5fa79d1c 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs @@ -14,9 +14,12 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; using System; using System.Collections.Generic; +using System.IO; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow { @@ -81,5 +84,35 @@ namespace Tensorflow } } } + + public static string freeze_graph(string checkpoint_dir, string output_pb_name) + { + var checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir); + if (!File.Exists($"{checkpoint}.meta")) return null; + + string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb")); + + using (var graph = tf.Graph()) + using (var sess = tf.Session(graph)) + { + var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true); + saver.restore(sess, checkpoint); + var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, + graph.as_graph_def(), + new string[] { "output/ArgMax" }); + Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); + File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); + return output_pb; + } + } + + public static Graph load_graph(string freeze_graph_pb, string name = "") + { + var bytes = File.ReadAllBytes(freeze_graph_pb); + var graph = tf.Graph().as_default(); + importer.import_graph_def(GraphDef.Parser.ParseFrom(bytes), + name: name); + return graph; + } } }