| @@ -0,0 +1,22 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class c_api_util | |||||
| { | |||||
| public static TF_Output tf_output(IntPtr c_op, int index) => new TF_Output(c_op, index); | |||||
| public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); | |||||
| public static IntPtr tf_buffer(byte[] data) | |||||
| { | |||||
| if (data != null) | |||||
| throw new NotImplementedException(""); | |||||
| // var buf = c_api.TF_NewBufferFromString(data); | |||||
| else | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,158 @@ | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using static Tensorflow.OpDef.Types; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class importer | |||||
| { | |||||
| public static ITensorOrOperation[] import_graph_def(GraphDef graph_def, | |||||
| Dictionary<string, Tensor> input_map = null, | |||||
| string[] return_elements = null, | |||||
| string name = "", | |||||
| OpList producer_op_list = null) | |||||
| { | |||||
| var op_dict = op_def_registry.get_registered_ops(); | |||||
| graph_def = _ProcessGraphDefParam(graph_def, op_dict); | |||||
| input_map = _ProcessInputMapParam(input_map); | |||||
| return_elements = _ProcessReturnElementsParam(return_elements); | |||||
| if (producer_op_list != null) | |||||
| _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def); | |||||
| string prefix = ""; | |||||
| var graph = ops.get_default_graph(); | |||||
| Python.with<ops.name_scope>(new ops.name_scope(name, "import", input_map.Values), scope => | |||||
| { | |||||
| /*prefix = scope; | |||||
| if (!string.IsNullOrEmpty(prefix)) | |||||
| prefix = prefix.Substring(0, prefix.Length - 1); | |||||
| else | |||||
| prefix = "";*/ | |||||
| // Generate any input map tensors inside name scope | |||||
| input_map = _ConvertInputMapValues(name, input_map); | |||||
| }); | |||||
| var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||||
| _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||||
| var bytes = graph_def.ToByteString().ToArray(); | |||||
| var status = new Status(); | |||||
| c_api.TF_GraphImportGraphDefWithResults(graph, IntPtr.Zero, scoped_options, status); | |||||
| throw new NotImplementedException("importer.import_graph_def"); | |||||
| } | |||||
| public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, | |||||
| string prefix, | |||||
| Dictionary<string, Tensor> input_map, | |||||
| string[] return_elements) | |||||
| { | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||||
| c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||||
| foreach(var input in input_map) | |||||
| { | |||||
| throw new NotImplementedException("_PopulateTFImportGraphDefOptions"); | |||||
| } | |||||
| if (return_elements == null) | |||||
| return_elements = new string[0]; | |||||
| foreach (var name in return_elements) | |||||
| { | |||||
| throw new NotImplementedException("_PopulateTFImportGraphDefOptions"); | |||||
| } | |||||
| } | |||||
| public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map) | |||||
| { | |||||
| return input_map; | |||||
| } | |||||
| public static GraphDef _ProcessGraphDefParam(GraphDef graph_def, Dictionary<string, OpDef> op_dict) | |||||
| { | |||||
| foreach(var node in graph_def.Node) | |||||
| { | |||||
| if (!op_dict.ContainsKey(node.Op)) | |||||
| continue; | |||||
| var op_def = op_dict[node.Op]; | |||||
| _SetDefaultAttrValues(node, op_def); | |||||
| } | |||||
| return graph_def; | |||||
| } | |||||
| private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def) | |||||
| { | |||||
| foreach(var attr_def in op_def.Attr) | |||||
| { | |||||
| var key = attr_def.Name; | |||||
| if(attr_def.DefaultValue != null) | |||||
| { | |||||
| var value = node_def.Attr[key]; | |||||
| if (value == null) | |||||
| node_def.Attr[key] = attr_def.DefaultValue; | |||||
| } | |||||
| } | |||||
| } | |||||
| private static Dictionary<string, Tensor> _ProcessInputMapParam(Dictionary<string, Tensor> input_map) | |||||
| { | |||||
| if (input_map == null) | |||||
| return new Dictionary<string, Tensor>(); | |||||
| return input_map; | |||||
| } | |||||
| private static string[] _ProcessReturnElementsParam(string[] return_elements) | |||||
| { | |||||
| if (return_elements == null) | |||||
| return null; | |||||
| return return_elements; | |||||
| } | |||||
| private static void _RemoveDefaultAttrs(Dictionary<string, OpDef> op_dict, OpList producer_op_list, GraphDef graph_def) | |||||
| { | |||||
| var producer_op_dict = new Dictionary<string, OpDef>(); | |||||
| producer_op_list.Op.Select(op => | |||||
| { | |||||
| producer_op_dict[op.Name] = op; | |||||
| return op; | |||||
| }).ToArray(); | |||||
| foreach(var node in graph_def.Node) | |||||
| { | |||||
| // Remove any default attr values that aren't in op_def. | |||||
| if (producer_op_dict.ContainsKey(node.Op)) | |||||
| { | |||||
| var op_def = op_dict[node.Op]; | |||||
| var producer_op_def = producer_op_dict[node.Op]; | |||||
| foreach(var key in node.Attr) | |||||
| { | |||||
| if(_FindAttrInOpDef(key.Key, op_def) == null) | |||||
| { | |||||
| var attr_def = _FindAttrInOpDef(key.Key, producer_op_def); | |||||
| if (attr_def != null && attr_def.DefaultValue != null && | |||||
| node.Attr[key.Key] == attr_def.DefaultValue) | |||||
| node.Attr[key.Key].ClearValue(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| private static AttrDef _FindAttrInOpDef(string name, OpDef op_def) | |||||
| { | |||||
| return op_def.Attr.FirstOrDefault(x => x.Name == name); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using static Tensorflow.MetaGraphDef.Types; | using static Tensorflow.MetaGraphDef.Types; | ||||
| @@ -8,6 +9,59 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class meta_graph | public class meta_graph | ||||
| { | { | ||||
| public static MetaGraphDef read_meta_graph_file(string filename) | |||||
| { | |||||
| var bytes = File.ReadAllBytes(filename); | |||||
| var meta_graph_def = MetaGraphDef.Parser.ParseFrom(bytes); | |||||
| return meta_graph_def; | |||||
| } | |||||
| public static void import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file, | |||||
| bool clear_devices = false, | |||||
| string import_scope = "", | |||||
| Dictionary<string, Tensor> input_map = null, | |||||
| string unbound_inputs_col_name = "unbound_inputs", | |||||
| string[] return_elements = null) | |||||
| { | |||||
| var meta_graph_def = meta_graph_or_file; | |||||
| if (!string.IsNullOrEmpty(unbound_inputs_col_name)) | |||||
| { | |||||
| foreach(var col in meta_graph_def.CollectionDef) | |||||
| { | |||||
| if(col.Key == unbound_inputs_col_name) | |||||
| { | |||||
| throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Sets graph to default graph if it's not passed in. | |||||
| var graph = ops.get_default_graph(); | |||||
| // Gathers the list of nodes we are interested in. | |||||
| OpList producer_op_list = null; | |||||
| if (meta_graph_def.MetaInfoDef.StrippedOpList != null) | |||||
| producer_op_list = meta_graph_def.MetaInfoDef.StrippedOpList; | |||||
| var input_graph_def = meta_graph_def.GraphDef; | |||||
| // Remove all the explicit device specifications for this node. This helps to | |||||
| // make the graph more portable. | |||||
| if (clear_devices) | |||||
| foreach (var node in input_graph_def.Node) | |||||
| node.Device = ""; | |||||
| var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false); | |||||
| importer.import_graph_def(input_graph_def, | |||||
| name: scope_to_prepend_to_names, | |||||
| input_map: input_map, | |||||
| producer_op_list: producer_op_list, | |||||
| return_elements: return_elements); | |||||
| // Restores all the other collections. | |||||
| var variable_objects = new Dictionary<string, RefVariable>(); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns `MetaGraphDef` proto. Optionally writes it to filename. | /// Returns `MetaGraphDef` proto. Optionally writes it to filename. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -218,6 +218,18 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); | public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); | ||||
| /// <summary> | |||||
| /// Set whether to uniquify imported operation names. If true, imported operation | |||||
| /// names will be modified if their name already exists in the graph. If false, | |||||
| /// conflicting names will be treated as an error. Note that this option has no | |||||
| /// effect if a prefix is set, since the prefix will guarantee all names are | |||||
| /// unique. Defaults to false. | |||||
| /// </summary> | |||||
| /// <param name="ops">TF_ImportGraphDefOptions*</param> | |||||
| /// <param name="uniquify_prefix">unsigned char</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_ImportGraphDefOptionsSetUniquifyNames(IntPtr ops, char uniquify_prefix); | |||||
| /// <summary> | /// <summary> | ||||
| /// Fetches the return operations requested via | /// Fetches the return operations requested via | ||||
| /// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched | /// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched | ||||
| @@ -4,7 +4,7 @@ | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <Version>0.1.0</Version> | |||||
| <Version>0.2.0</Version> | |||||
| <Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| @@ -16,10 +16,13 @@ | |||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | ||||
| <Description>Google's TensorFlow binding in .NET Standard. | <Description>Google's TensorFlow binding in .NET Standard. | ||||
| Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.1.0.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Implemented the tf.Variable(). | |||||
| TensorFlow 1.13 RC.</PackageReleaseNotes> | |||||
| <AssemblyVersion>0.2.0.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Added a bunch of APIs. | |||||
| Fixed String tensor creation bug. | |||||
| Upgraded to TensorFlow 1.13 RC-1. | |||||
| </PackageReleaseNotes> | |||||
| <LangVersion>7.2</LangVersion> | <LangVersion>7.2</LangVersion> | ||||
| <FileVersion>0.2.0.0</FileVersion> | |||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
| @@ -193,6 +193,13 @@ namespace Tensorflow | |||||
| return _is_empty ? string.Empty : model_checkpoint_path; | return _is_empty ? string.Empty : model_checkpoint_path; | ||||
| } | } | ||||
| public Saver import_meta_graph(string meta_graph_or_file, | |||||
| bool clear_devices = false, | |||||
| string import_scope = "") | |||||
| { | |||||
| return saver._import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Writes `MetaGraphDef` to save_path/filename. | /// Writes `MetaGraphDef` to save_path/filename. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,27 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class saver | |||||
| { | |||||
| public static Saver _import_meta_graph_with_return_elements(string meta_graph_or_file, | |||||
| bool clear_devices = false, | |||||
| string import_scope = "", | |||||
| string[] return_elements = null) | |||||
| { | |||||
| var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file); | |||||
| meta_graph.import_scoped_meta_graph_with_return_elements( | |||||
| meta_graph_def, | |||||
| clear_devices: clear_devices, | |||||
| import_scope: import_scope, | |||||
| return_elements: return_elements); | |||||
| return null; | |||||
| /*var (imported_vars, imported_return_elements) = ( | |||||
| , false);*/ | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -14,6 +14,12 @@ namespace Tensorflow | |||||
| public static Saver Saver() => new Saver(); | public static Saver Saver() => new Saver(); | ||||
| public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); | public static string write_graph(Graph graph, string logdir, string name, bool as_text = true) => graph_io.write_graph(graph, logdir, name, as_text); | ||||
| public static Saver import_meta_graph(string meta_graph_or_file, | |||||
| bool clear_devices = false, | |||||
| string import_scope = "") => saver._import_meta_graph_with_return_elements(meta_graph_or_file, | |||||
| clear_devices, | |||||
| import_scope); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,14 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class c_api_util | |||||
| { | |||||
| public static TF_Output tf_output(IntPtr c_op, int index) | |||||
| { | |||||
| return new TF_Output(c_op, index); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -7,7 +7,7 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="NumSharp" Version="0.7.1" /> | <PackageReference Include="NumSharp" Version="0.7.1" /> | ||||
| <PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | |||||
| <PackageReference Include="TensorFlow.NET" Version="0.2.0" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -20,7 +20,7 @@ | |||||
| <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
| <PackageReference Include="NumSharp" Version="0.7.1" /> | <PackageReference Include="NumSharp" Version="0.7.1" /> | ||||
| <PackageReference Include="TensorFlow.NET" Version="0.1.0" /> | |||||
| <PackageReference Include="TensorFlow.NET" Version="0.2.0" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -20,9 +20,10 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void ImportGraph() | public void ImportGraph() | ||||
| { | { | ||||
| var v = tf.Variable(0, name: "my_variable"); | |||||
| var sess = tf.Session(); | |||||
| tf.train.write_graph(sess.graph, "/tmp/my-model", "train2.pbtxt"); | |||||
| with<Session>(tf.Session(), sess => | |||||
| { | |||||
| var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); | |||||
| }); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -45,6 +46,7 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Save2() | public void Save2() | ||||
| { | { | ||||
| var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | ||||