diff --git a/src/TensorFlowNET.Core/Framework/c_api_util.py.cs b/src/TensorFlowNET.Core/Framework/c_api_util.py.cs new file mode 100644 index 00000000..2ff54c0e --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/c_api_util.py.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class c_api_util + { + public static TF_Output tf_output(IntPtr c_op, int index) => new TF_Output(c_op, index); + + public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); + + public static IntPtr tf_buffer(byte[] data) + { + if (data != null) + throw new NotImplementedException(""); + // var buf = c_api.TF_NewBufferFromString(data); + else + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/importer.py.cs b/src/TensorFlowNET.Core/Framework/importer.py.cs new file mode 100644 index 00000000..2fdc8985 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/importer.py.cs @@ -0,0 +1,158 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.OpDef.Types; + +namespace Tensorflow +{ + public class importer + { + public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, + Dictionary input_map = null, + string[] return_elements = null, + string name = "", + OpList producer_op_list = null) + { + var op_dict = op_def_registry.get_registered_ops(); + + graph_def = _ProcessGraphDefParam(graph_def, op_dict); + input_map = _ProcessInputMapParam(input_map); + return_elements = _ProcessReturnElementsParam(return_elements); + + if (producer_op_list != null) + _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def); + + string prefix = ""; + var graph = ops.get_default_graph(); + Python.with(new ops.name_scope(name, "import", input_map.Values), scope => + { + /*prefix = scope; + if (!string.IsNullOrEmpty(prefix)) + prefix = prefix.Substring(0, prefix.Length - 1); + else + prefix = "";*/ + + // Generate any input map tensors inside name scope + input_map = _ConvertInputMapValues(name, input_map); + }); + + var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); + _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); + + var bytes = graph_def.ToByteString().ToArray(); + + var status = new Status(); + c_api.TF_GraphImportGraphDefWithResults(graph, IntPtr.Zero, scoped_options, status); + + throw new NotImplementedException("importer.import_graph_def"); + } + + public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, + string prefix, + Dictionary input_map, + string[] return_elements) + { + c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); + c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); + + foreach(var input in input_map) + { + throw new NotImplementedException("_PopulateTFImportGraphDefOptions"); + } + + if (return_elements == null) + return_elements = new string[0]; + + foreach (var name in return_elements) + { + throw new NotImplementedException("_PopulateTFImportGraphDefOptions"); + } + } + + public static Dictionary _ConvertInputMapValues(string name, Dictionary input_map) + { + return input_map; + } + + public static GraphDef _ProcessGraphDefParam(GraphDef graph_def, Dictionary op_dict) + { + foreach(var node in graph_def.Node) + { + if (!op_dict.ContainsKey(node.Op)) + continue; + + var op_def = op_dict[node.Op]; + _SetDefaultAttrValues(node, op_def); + } + + return graph_def; + } + + private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) + { + foreach(var attr_def in op_def.Attr) + { + var key = attr_def.Name; + if(attr_def.DefaultValue != null) + { + var value = node_def.Attr[key]; + if (value == null) + node_def.Attr[key] = attr_def.DefaultValue; + } + } + } + + private static Dictionary _ProcessInputMapParam(Dictionary input_map) + { + if (input_map == null) + return new Dictionary(); + + return input_map; + } + + private static string[] _ProcessReturnElementsParam(string[] return_elements) + { + if (return_elements == null) + return null; + + return return_elements; + } + + private static void _RemoveDefaultAttrs(Dictionary op_dict, OpList producer_op_list, GraphDef graph_def) + { + var producer_op_dict = new Dictionary(); + producer_op_list.Op.Select(op => + { + producer_op_dict[op.Name] = op; + return op; + }).ToArray(); + + foreach(var node in graph_def.Node) + { + // Remove any default attr values that aren't in op_def. + if (producer_op_dict.ContainsKey(node.Op)) + { + var op_def = op_dict[node.Op]; + var producer_op_def = producer_op_dict[node.Op]; + foreach(var key in node.Attr) + { + if(_FindAttrInOpDef(key.Key, op_def) == null) + { + var attr_def = _FindAttrInOpDef(key.Key, producer_op_def); + if (attr_def != null && attr_def.DefaultValue != null && + node.Attr[key.Key] == attr_def.DefaultValue) + node.Attr[key.Key].ClearValue(); + } + } + } + } + } + + private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) + { + return op_def.Attr.FirstOrDefault(x => x.Name == name); + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index 8df307d1..6e8d354a 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Text; using static Tensorflow.MetaGraphDef.Types; @@ -8,6 +9,59 @@ namespace Tensorflow { public class meta_graph { + public static MetaGraphDef read_meta_graph_file(string filename) + { + var bytes = File.ReadAllBytes(filename); + var meta_graph_def = MetaGraphDef.Parser.ParseFrom(bytes); + return meta_graph_def; + } + + public static void import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, + bool clear_devices = false, + string import_scope = "", + Dictionary input_map = null, + string unbound_inputs_col_name = "unbound_inputs", + string[] return_elements = null) + { + var meta_graph_def = meta_graph_or_file; + + if (!string.IsNullOrEmpty(unbound_inputs_col_name)) + { + foreach(var col in meta_graph_def.CollectionDef) + { + if(col.Key == unbound_inputs_col_name) + { + throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); + } + } + } + + // Sets graph to default graph if it's not passed in. + var graph = ops.get_default_graph(); + + // Gathers the list of nodes we are interested in. + OpList producer_op_list = null; + if (meta_graph_def.MetaInfoDef.StrippedOpList != null) + producer_op_list = meta_graph_def.MetaInfoDef.StrippedOpList; + var input_graph_def = meta_graph_def.GraphDef; + // Remove all the explicit device specifications for this node. This helps to + // make the graph more portable. + if (clear_devices) + foreach (var node in input_graph_def.Node) + node.Device = ""; + + var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false); + importer.import_graph_def(input_graph_def, + name: scope_to_prepend_to_names, + input_map: input_map, + producer_op_list: producer_op_list, + return_elements: return_elements); + + // Restores all the other collections. + var variable_objects = new Dictionary(); + + } + /// /// Returns `MetaGraphDef` proto. Optionally writes it to filename. /// diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 52c23785..5e15d507 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -218,6 +218,18 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); + /// + /// Set whether to uniquify imported operation names. If true, imported operation + /// names will be modified if their name already exists in the graph. If false, + /// conflicting names will be treated as an error. Note that this option has no + /// effect if a prefix is set, since the prefix will guarantee all names are + /// unique. Defaults to false. + /// + /// TF_ImportGraphDefOptions* + /// unsigned char + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(IntPtr ops, char uniquify_prefix); + /// /// Fetches the return operations requested via /// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 2d390319..71af7b94 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -4,7 +4,7 @@ netstandard2.0 TensorFlow.NET Tensorflow - 0.1.0 + 0.2.0 Haiping Chen SciSharp STACK true @@ -16,10 +16,13 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET Google's TensorFlow binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.1.0.0 - Implemented the tf.Variable(). -TensorFlow 1.13 RC. + 0.2.0.0 + Added a bunch of APIs. +Fixed String tensor creation bug. +Upgraded to TensorFlow 1.13 RC-1. + 7.2 + 0.2.0.0 diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index f6b6b8af..816ffea7 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -193,6 +193,13 @@ namespace Tensorflow return _is_empty ? string.Empty : model_checkpoint_path; } + public Saver import_meta_graph(string meta_graph_or_file, + bool clear_devices = false, + string import_scope = "") + { + return saver._import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope); + } + /// /// Writes `MetaGraphDef` to save_path/filename. /// diff --git a/src/TensorFlowNET.Core/Train/Saving/saver.py.cs b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs new file mode 100644 index 00000000..344e2078 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/Saving/saver.py.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class saver + { + public static Saver _import_meta_graph_with_return_elements(string meta_graph_or_file, + bool clear_devices = false, + string import_scope = "", + string[] return_elements = null) + { + var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); + + meta_graph.import_scoped_meta_graph_with_return_elements( + meta_graph_def, + clear_devices: clear_devices, + import_scope: import_scope, + return_elements: return_elements); + + return null; + /*var (imported_vars, imported_return_elements) = ( + , false);*/ + } + } +} diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/Train/tf.optimizers.cs index 7d5f1527..8579047a 100644 --- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs +++ b/src/TensorFlowNET.Core/Train/tf.optimizers.cs @@ -14,6 +14,12 @@ namespace Tensorflow public static Saver Saver() => new Saver(); public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); + + public static Saver import_meta_graph(string meta_graph_or_file, + bool clear_devices = false, + string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file, + clear_devices, + import_scope); } } } diff --git a/src/TensorFlowNET.Core/c_api_util.cs b/src/TensorFlowNET.Core/c_api_util.cs deleted file mode 100644 index 4ee32805..00000000 --- a/src/TensorFlowNET.Core/c_api_util.cs +++ /dev/null @@ -1,14 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow -{ - public class c_api_util - { - public static TF_Output tf_output(IntPtr c_op, int index) - { - return new TF_Output(c_op, index); - } - } -} diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 4602f38e..ce9ebf07 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -7,7 +7,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index f090bdc8..4dfde4fd 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -20,7 +20,7 @@ - + diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index 9e86e8d5..34a5808b 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -20,9 +20,10 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void ImportGraph() { - var v = tf.Variable(0, name: "my_variable"); - var sess = tf.Session(); - tf.train.write_graph(sess.graph, "/tmp/my-model", "train2.pbtxt"); + with(tf.Session(), sess => + { + var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); + }); } [TestMethod] @@ -45,6 +46,7 @@ namespace TensorFlowNET.UnitTest }); } + [TestMethod] public void Save2() { var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);