From 21cf2be6603c85c4b1192dc37cdb0f9b4f02b108 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 9 Feb 2020 16:42:35 -0600 Subject: [PATCH] fix freeze_graph output_node_names --- src/TensorFlowNET.Core/APIs/tf.train.cs | 4 ++-- src/TensorFlowNET.Core/TensorFlow.Binding.csproj | 6 +++--- src/TensorFlowNET.Core/Training/Saving/saver.py.cs | 6 ++++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index 8d9957ac..b9bc430d 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -56,8 +56,8 @@ namespace Tensorflow public Graph load_graph(string freeze_graph_pb) => saver.load_graph(freeze_graph_pb); - public string freeze_graph(string checkpoint_dir, string output_pb_name) - => saver.freeze_graph(checkpoint_dir, output_pb_name); + public string freeze_graph(string checkpoint_dir, string output_pb_name, string[] output_node_names) + => saver.freeze_graph(checkpoint_dir, output_pb_name, output_node_names); public Saver import_meta_graph(string meta_graph_or_file, bool clear_devices = false, diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index cf3540c2..dd83ae65 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.1 - 0.14.1 + 0.14.1.1 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK true @@ -18,12 +18,12 @@ Google's TensorFlow full binding in .NET Standard. Building, training and infering deep learning models. https://tensorflownet.readthedocs.io - 0.14.1.0 + 0.14.1.1 Changes since v0.14.0: 1: Add TransformGraphWithStringInputs. 2: tf.trainer.load_graph, tf.trainer.freeze_graph 7.3 - 0.14.1.0 + 0.14.1.1 LICENSE true true diff --git a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs index 5fa79d1c..5f119791 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs @@ -85,7 +85,9 @@ namespace Tensorflow } } - public static string freeze_graph(string checkpoint_dir, string output_pb_name) + public static string freeze_graph(string checkpoint_dir, + string output_pb_name, + string[] output_node_names) { var checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir); if (!File.Exists($"{checkpoint}.meta")) return null; @@ -99,7 +101,7 @@ namespace Tensorflow saver.restore(sess, checkpoint); var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), - new string[] { "output/ArgMax" }); + output_node_names); Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); return output_pb;