diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index 1d07747b..da7fb027 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -23,14 +23,14 @@ namespace Tensorflow { public VariableV1[] global_variables(string scope = null) { - return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope)) + return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List) .ToArray(); } public Operation global_variables_initializer() { var g = variables.global_variables(); - return variables.variables_initializer(g?.ToArray()); + return variables.variables_initializer(g.ToArray()); } /// @@ -54,9 +54,9 @@ namespace Tensorflow { var scope = Tensorflow.variable_scope.get_variable_scope(); var store = Tensorflow.variable_scope._get_default_variable_store(); - return scope.get_variable(store, - name, - shape: shape, + return scope.get_variable(store, + name, + shape: shape, dtype: dtype, use_resource: use_resource, validate_shape: validate_shape, diff --git a/src/TensorFlowNET.Core/Binding.FuncTools.cs b/src/TensorFlowNET.Core/Binding.FuncTools.cs index fb038005..71240c67 100644 --- a/src/TensorFlowNET.Core/Binding.FuncTools.cs +++ b/src/TensorFlowNET.Core/Binding.FuncTools.cs @@ -10,11 +10,23 @@ namespace Tensorflow { public static class functools { - public static Func partial(Func func, Tin arg) - => (arg0) => func(arg0); + public static PartialFunc partial(Func func, Tin arg) + => new PartialFunc + { + args = arg, + invoke = func + }; public static Func partial(Func func, (Tin1, Tin2) args) - => (arg1, arg2) => func(arg1, arg2); + => (arg1, arg2) => func(args.Item1, args.Item2); + } + + public class PartialFunc + { + public Tin args { get; set; } + public object[] keywords { get; set; } + + public Func invoke { get; set; } } } } diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index 21684001..d80e67b4 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -46,9 +46,9 @@ namespace Tensorflow if (!string.IsNullOrEmpty(unbound_inputs_col_name)) { - foreach(var col in meta_graph_def.CollectionDef) + foreach (var col in meta_graph_def.CollectionDef) { - if(col.Key == unbound_inputs_col_name) + if (col.Key == unbound_inputs_col_name) { throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); } @@ -78,7 +78,7 @@ namespace Tensorflow // Restores all the other collections. var variable_objects = new Dictionary(); - foreach(var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) + foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key)) { // Don't add unbound_inputs to the new graph. if (col.Key == unbound_inputs_col_name) @@ -87,7 +87,7 @@ namespace Tensorflow switch (col.Value.KindCase) { case KindOneofCase.NodeList: - foreach(var value in col.Value.NodeList.Value) + foreach (var value in col.Value.NodeList.Value) { var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names)); graph.add_to_collection(col.Key, col_op); @@ -115,7 +115,7 @@ namespace Tensorflow } else { - foreach(var value in col.Value.BytesList.Value) + foreach (var value in col.Value.BytesList.Value) { switch (col.Key) { @@ -139,7 +139,7 @@ namespace Tensorflow } } } - + break; default: throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); @@ -173,8 +173,8 @@ namespace Tensorflow string unbound_inputs_col_name = "unbound_inputs", bool clear_devices = false, SaverDef saver_def = null, - bool clear_extraneous_savers= false, - bool strip_default_attrs= false, + bool clear_extraneous_savers = false, + bool strip_default_attrs = false, byte[] meta_info_def = null) { var graph = ops.get_default_graph(); @@ -236,12 +236,12 @@ namespace Tensorflow meta_graph_def.GraphDef = graph_def; // Fills in meta_info_def.stripped_op_list using the ops from graph_def. - if (meta_graph_def.MetaInfoDef.StrippedOpList == null || + if (meta_graph_def.MetaInfoDef.StrippedOpList == null || meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0) meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef); var clist = graph.get_all_collection_keys(); - foreach(var ctype in clist) + foreach (var ctype in clist) { if (clear_extraneous_savers) { @@ -256,30 +256,34 @@ namespace Tensorflow return meta_graph_def; } - private static void add_collection_def(MetaGraphDef meta_graph_def, - string key, + private static void add_collection_def(MetaGraphDef meta_graph_def, + string key, Graph graph = null, string export_scope = "") { if (!meta_graph_def.CollectionDef.ContainsKey(key)) meta_graph_def.CollectionDef[key] = new CollectionDef(); var col_def = meta_graph_def.CollectionDef[key]; - col_def.NodeList = new Types.NodeList(); - col_def.BytesList = new Types.BytesList(); - foreach (object value in graph.get_collection(key)) + + switch (graph.get_collection(key)) { - switch (value) - { - case RefVariable x: + case List collection_list: + col_def.BytesList = new Types.BytesList(); + foreach (var x in collection_list) + { var proto = x.to_proto(export_scope); col_def.BytesList.Value.Add(proto.ToByteString()); - break; - case ITensorOrOperation x2: - col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope)); - break; - default: - break; - } + } + + break; + case List collection_list: + col_def.NodeList = new Types.NodeList(); + foreach (var x in collection_list) + if (x is ITensorOrOperation x2) + col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope)); + break; + case List collection_list: + break; } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 767e23f7..a162f54d 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -77,7 +77,7 @@ namespace Tensorflow /// https://www.tensorflow.org/guide/graphs

https://www.tensorflow.org/api_docs/python/tf/Graph
public partial class Graph : DisposableObject #if !SERIALIZABLE - ,IEnumerable + , IEnumerable #endif { private Dictionary _nodes_by_id; @@ -100,15 +100,13 @@ namespace Tensorflow /// private bool _finalized = false; - /// - /// Arbitrary collections of objects inside the graph. - /// TODO: Access might be slow (-> O(n)) depending on size. + /// Arbitrary collections of objects. /// - private readonly ICollection<(string name, string scope, object item)> _collections = new List<(string name, string scope, object item)>(); + private Dictionary _collections = new Dictionary(); - public bool building_function; - + public bool building_function; + public Graph() { _handle = c_api.TF_NewGraph(); @@ -230,14 +228,16 @@ namespace Tensorflow throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}."); } - public void add_to_collection(string name, object value) + public void add_to_collection(string name, T value) { _check_not_finalized(); - _collections.Add((name, null, value)); + if (_collections.ContainsKey(name)) + (_collections[name] as List).Add(value); + else + _collections[name] = new List { value }; } - - public void add_to_collections(List names, object value) + public void add_to_collections(List names, T value) { foreach (string name in names) add_to_collection(name, value); @@ -278,6 +278,12 @@ namespace Tensorflow _create_op_helper(op, true); + /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); + Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); + Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}"); + Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}"); + Console.WriteLine();*/ + return op; } @@ -394,7 +400,7 @@ namespace Tensorflow _names_in_use[name_key] = 1; // Return the new name with the original capitalization of the given name. - name = $"{name}_{i-1}"; + name = $"{name}_{i - 1}"; } return name; } @@ -407,8 +413,8 @@ namespace Tensorflow TF_Output[] return_outputs = new TF_Output[num_return_outputs]; unsafe { - var tf_output_ptr = (TF_Output*) return_output_handle; - for (int i = 0; i < num_return_outputs; i++) + var tf_output_ptr = (TF_Output*)return_output_handle; + for (int i = 0; i < num_return_outputs; i++) return_outputs[i] = *(tf_output_ptr + i); return return_outputs; } @@ -416,34 +422,46 @@ namespace Tensorflow public string[] get_all_collection_keys() { - return (from c in _collections where !c.name.StartsWith("__") select c.name).ToArray(); + return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); } - public List get_collection(string name, string scope = null) + public object get_collection(string name, string scope = null) { - return get_collection(name, scope); - } - - + return _collections.ContainsKey(name) ? _collections[name] : null; + } + public List get_collection(string name, string scope = null) - { - - return (from c in _collections - where c.name == name && - (scope == null || c.scope == scope) && - implementationOf(c.item) - select (T)(c.item)).ToList(); - - } - - private static bool implementationOf(object item) - { - return (item.GetType() == typeof(T) || item.GetType().IsSubclassOf(typeof(T))); - } - + { + List t = default; + var collection = _collections.ContainsKey(name) ? _collections[name] : new List(); + switch (collection) + { + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + case List list: + t = list.Select(x => (T)(object)x).ToList(); + break; + default: + throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); + } + return t; + } + public List get_collection_ref(string name) { - return get_collection(name); + if (!_collections.ContainsKey(name)) + _collections[name] = new List(); + return _collections[name] as List; } public void prevent_feeding(Tensor tensor) @@ -497,7 +515,7 @@ namespace Tensorflow string debugString = string.Empty; public override string ToString() { - return $"{graph_key}, ({_handle})"; + return $"{graph_key}, ({_handle})"; /*if (string.IsNullOrEmpty(debugString)) { int len = 0; @@ -514,7 +532,7 @@ namespace Tensorflow IEnumerator IEnumerable.GetEnumerator() => GetEnumerable().GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() + IEnumerator IEnumerable.GetEnumerator() => throw new NotImplementedException(); #endif diff --git a/src/TensorFlowNET.Core/Training/TrainingUtil.cs b/src/TensorFlowNET.Core/Training/TrainingUtil.cs index 63227733..9e784550 100644 --- a/src/TensorFlowNET.Core/Training/TrainingUtil.cs +++ b/src/TensorFlowNET.Core/Training/TrainingUtil.cs @@ -16,7 +16,7 @@ namespace Tensorflow.Train // Create in proper graph and base name_scope. var g = graph.as_default(); g.name_scope(null); - var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64, + var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new int[0], dtype: dtypes.int64, initializer: tf.zeros_initializer, trainable: false, aggregation: VariableAggregation.OnlyFirstReplica, diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index fedb1a27..41e8e429 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -43,7 +43,7 @@ namespace Tensorflow protected Graph _graph; bool _building_function; - public variable_scope(string name, + public variable_scope(string name, string default_name = "", Tensor[] values = null, bool? reuse = null, @@ -113,7 +113,7 @@ namespace Tensorflow { // Reenter the current name scope string name_scope = ops.get_name_scope(); - if(!string.IsNullOrEmpty(name_scope)) + if (!string.IsNullOrEmpty(name_scope)) // Hack to reenter name_scope += "/"; current_name_scope = ops.name_scope(name_scope); @@ -128,8 +128,8 @@ namespace Tensorflow string current_name_scope_name = current_name_scope; _current_name_scope = current_name_scope; string old_name_scope = _scope == null ? current_name_scope_name : _scope.original_name_scope; - - if(_scope == null) + + if (_scope == null) pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope); else pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope); @@ -179,7 +179,7 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, int[] shape = null, bool validate_shape = false, - bool ? use_resource = null, + bool? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation = VariableAggregation.None) { @@ -189,7 +189,7 @@ namespace Tensorflow use_resource = get_variable_scope().use_resource; } - if(!use_resource.HasValue) + if (!use_resource.HasValue) use_resource = _DEFAULT_USE_RESOURCE; if (use_resource.Value) @@ -204,7 +204,7 @@ namespace Tensorflow } else { - return new RefVariable(initial_value, + return new RefVariable(initial_value, trainable: trainable.Value, validate_shape: validate_shape, collections: collections, @@ -215,13 +215,13 @@ namespace Tensorflow public static _VariableStore _get_default_variable_store() { - var store = ops.get_collection<_VariableStore>(_VARSTORE_KEY).FirstOrDefault(); - if (store == null) - { - store = new _VariableStore(); - ops.add_to_collection(_VARSTORE_KEY, store); - } - return store; + var store = ops.get_collection(_VARSTORE_KEY); + if (store != null) + return (store as List<_VariableStore>)[0]; + + var store1 = new _VariableStore(); + ops.add_to_collection(_VARSTORE_KEY, store1); + return store1; } public static VariableScope get_variable_scope() @@ -231,15 +231,30 @@ namespace Tensorflow public static _VariableScopeStore get_variable_scope_store() { - var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault(); - if (scope_store == null) - scope_store = ops.get_collection(_VARSCOPESTORE_KEY).FirstOrDefault(); + _VariableScopeStore ret = null; + var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); if (scope_store == null) { - scope_store = new _VariableScopeStore(); - ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); + ret = new _VariableScopeStore(); + ops.add_to_collection(_VARSCOPESTORE_KEY, ret); + } + else + { + switch (scope_store) + { + case List values: + ret = values[0]; + break; + case List<_VariableScopeStore> values: + ret = values[0]; + break; + default: + throw new InvalidOperationException("get_variable_scope_store"); + } + } - return scope_store; + + return ret; } public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true) @@ -256,7 +271,7 @@ namespace Tensorflow { trainable = true; } - + return trainable.Value; } @@ -279,7 +294,7 @@ namespace Tensorflow } // TODO for Switch/Case - public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource, + public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource, TensorShape shape = null, TF_DataType dtype = TF_DataType.DtInvalid, bool trainable = false, @@ -290,12 +305,12 @@ namespace Tensorflow public void __init__() { - + } public void __del__() { - + } } } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index df43dc63..02417594 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -63,7 +63,7 @@ namespace Tensorflow /// list contains the values in the order under which they were /// collected. /// - public static List get_collection(string key, string scope = null) + public static object get_collection(string key, string scope = null) { return get_default_graph().get_collection(key, scope); }