| @@ -17,5 +17,18 @@ namespace Tensorflow | |||||
| return buffer; | return buffer; | ||||
| } | } | ||||
| public GraphDef _as_graph_def() | |||||
| { | |||||
| var buffer = ToGraphDef(Status); | |||||
| Status.Check(); | |||||
| var def = GraphDef.Parser.ParseFrom(buffer); | |||||
| buffer.Dispose(); | |||||
| // Strip the experimental library field iff it's empty. | |||||
| // if(def.Library.Function.Count == 0) | |||||
| return def; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ namespace Tensorflow | |||||
| return null; | return null; | ||||
| } | } | ||||
| private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) | private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true) | ||||
| { | { | ||||
| string types_str = ""; | string types_str = ""; | ||||
| @@ -0,0 +1,21 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class graph_io | |||||
| { | |||||
| public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) | |||||
| { | |||||
| var def = graph._as_graph_def(); | |||||
| string path = Path.Combine(logdir, name); | |||||
| string text = def.ToString(); | |||||
| if (as_text) | |||||
| File.WriteAllText(path, text); | |||||
| return path; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -11,6 +12,8 @@ namespace Tensorflow | |||||
| public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); | public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate); | ||||
| public static Saver Saver() => new Saver(); | public static Saver Saver() => new Saver(); | ||||
| public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -9,6 +9,14 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestClass] | [TestClass] | ||||
| public class TrainSaverTest : Python | public class TrainSaverTest : Python | ||||
| { | { | ||||
| [TestMethod] | |||||
| public void WriteGraph() | |||||
| { | |||||
| var v = tf.Variable(0, name: "my_variable"); | |||||
| var sess = tf.Session(); | |||||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train.pbtxt"); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Save() | public void Save() | ||||
| { | { | ||||