|
|
|
@@ -14,9 +14,12 @@ |
|
|
|
limitations under the License. |
|
|
|
******************************************************************************/ |
|
|
|
|
|
|
|
using Google.Protobuf; |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.IO; |
|
|
|
using System.Linq; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
|
{ |
|
|
|
@@ -81,5 +84,35 @@ namespace Tensorflow |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public static string freeze_graph(string checkpoint_dir, string output_pb_name) |
|
|
|
{ |
|
|
|
var checkpoint = checkpoint_management.latest_checkpoint(checkpoint_dir); |
|
|
|
if (!File.Exists($"{checkpoint}.meta")) return null; |
|
|
|
|
|
|
|
string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb")); |
|
|
|
|
|
|
|
using (var graph = tf.Graph()) |
|
|
|
using (var sess = tf.Session(graph)) |
|
|
|
{ |
|
|
|
var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true); |
|
|
|
saver.restore(sess, checkpoint); |
|
|
|
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, |
|
|
|
graph.as_graph_def(), |
|
|
|
new string[] { "output/ArgMax" }); |
|
|
|
Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); |
|
|
|
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); |
|
|
|
return output_pb; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public static Graph load_graph(string freeze_graph_pb, string name = "") |
|
|
|
{ |
|
|
|
var bytes = File.ReadAllBytes(freeze_graph_pb); |
|
|
|
var graph = tf.Graph().as_default(); |
|
|
|
importer.import_graph_def(GraphDef.Parser.ParseFrom(bytes), |
|
|
|
name: name); |
|
|
|
return graph; |
|
|
|
} |
|
|
|
} |
|
|
|
} |