From b87081cc4ac47dd3ad28f40cd899a228c0fc3552 Mon Sep 17 00:00:00 2001 From: "Mascha, Philipp" Date: Thu, 14 Nov 2019 14:07:59 +0100 Subject: [PATCH] Made internal collection of graph a list. --- src/TensorFlowNET.Core/APIs/tf.variable.cs | 4 +- .../Framework/meta_graph.cs | 30 ++++----- src/TensorFlowNET.Core/Graphs/Graph.cs | 67 ++++++------------- src/TensorFlowNET.Core/Summaries/Summary.cs | 7 +- .../Variables/variable_scope.py.cs | 43 +++++------- .../Variables/variables.py.cs | 12 +--- src/TensorFlowNET.Core/ops.cs | 2 +- .../Keras/EmbeddingTest.cs | 1 - test/TensorFlowNET.UnitTest/VariableTest.cs | 5 +- 9 files changed, 61 insertions(+), 110 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index 179cedee..1d07747b 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -23,14 +23,14 @@ namespace Tensorflow { public VariableV1[] global_variables(string scope = null) { - return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List) + return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope)) .ToArray(); } public Operation global_variables_initializer() { var g = variables.global_variables(); - return variables.variables_initializer(g.ToArray()); + return variables.variables_initializer(g?.ToArray()); } /// diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index 05092581..21684001 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -264,26 +264,22 @@ namespace Tensorflow if (!meta_graph_def.CollectionDef.ContainsKey(key)) meta_graph_def.CollectionDef[key] = new CollectionDef(); var col_def = meta_graph_def.CollectionDef[key]; - - switch (graph.get_collection(key)) + col_def.NodeList = new Types.NodeList(); + col_def.BytesList = new Types.BytesList(); + foreach (object value in graph.get_collection(key)) { - case List collection_list: - col_def.BytesList = new Types.BytesList(); - foreach (var x in collection_list) - { + switch (value) + { + case RefVariable x: var proto = x.to_proto(export_scope); col_def.BytesList.Value.Add(proto.ToByteString()); - } - - break; - case List collection_list: - col_def.NodeList = new Types.NodeList(); - foreach (var x in collection_list) - if (x is ITensorOrOperation x2) - col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope)); - break; - case List collection_list: - break; + break; + case ITensorOrOperation x2: + col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope)); + break; + default: + break; + } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 7ed19439..29514cd0 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -100,10 +100,12 @@ namespace Tensorflow /// private bool _finalized = false; + /// - /// Arbitrary collections of objects. + /// Arbitrary collections of objects inside the graph. + /// TODO: Access might be slow (-> O(n)) depending on size. /// - private Dictionary _collections = new Dictionary(); + private readonly ICollection<(string name, string scope, object item)> _collections = new List<(string name, string scope, object item)>(); public bool building_function; @@ -228,16 +230,14 @@ namespace Tensorflow throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}."); } - public void add_to_collection(string name, T value) + public void add_to_collection(string name, object value) { _check_not_finalized(); - if (_collections.ContainsKey(name)) - (_collections[name] as List).Add(value); - else - _collections[name] = new List { value }; + _collections.Add((name, null, value)); } - public void add_to_collections(List names, T value) + + public void add_to_collections(List names, object value) { foreach (string name in names) add_to_collection(name, value); @@ -278,12 +278,6 @@ namespace Tensorflow _create_op_helper(op, true); - /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); - Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); - Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}"); - Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}"); - Console.WriteLine();*/ - return op; } @@ -422,46 +416,29 @@ namespace Tensorflow public string[] get_all_collection_keys() { - return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); + return (from c in _collections where !c.name.StartsWith("__") select c.name).ToArray(); } - public object get_collection(string name, string scope = null) + public List get_collection(string name, string scope = null) { - return _collections.ContainsKey(name) ? _collections[name] : null; - } - + return get_collection(name, scope); + } + + private IEnumerable findObjects(string name, string scope) + { + return (from c in _collections where c.name == name && (scope == null || c.scope == scope) select c.item); + } + public List get_collection(string name, string scope = null) { - List t = default; - var collection = _collections.ContainsKey(name) ? _collections[name] : new List(); - switch (collection) - { - case List list: - t = list.Select(x => (T)(object)x).ToList(); - break; - case List list: - t = list.Select(x => (T)(object)x).ToList(); - break; - case List list: - t = list.Select(x => (T)(object)x).ToList(); - break; - case List list: - t = list.Select(x => (T)(object)x).ToList(); - break; - case List list: - t = list.Select(x => (T)(object)x).ToList(); - break; - default: - throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); - } - return t; + + return (from c in findObjects(name, scope) where c.GetType().IsSubclassOf(typeof(T)) select (T)c).ToList(); + } public List get_collection_ref(string name) { - if (!_collections.ContainsKey(name)) - _collections[name] = new List(); - return _collections[name] as List; + return get_collection(name); } public void prevent_feeding(Tensor tensor) diff --git a/src/TensorFlowNET.Core/Summaries/Summary.cs b/src/TensorFlowNET.Core/Summaries/Summary.cs index 3d157bd9..84889845 100644 --- a/src/TensorFlowNET.Core/Summaries/Summary.cs +++ b/src/TensorFlowNET.Core/Summaries/Summary.cs @@ -39,11 +39,8 @@ namespace Tensorflow.Summaries public Tensor merge_all(string key = "summaries", string scope= null, string name= null) { - var summary_ops = ops.get_collection(key, scope: scope); - if (summary_ops == null) - return null; - else - return merge((summary_ops as List).Select(x => x as Tensor).ToArray(), name: name); + var summary_ops = ops.get_collection(key, scope: scope); + return merge(summary_ops.Select(x => x as Tensor).ToArray(), name: name); } /// diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index f4a01054..b43663c0 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -215,13 +215,13 @@ namespace Tensorflow public static _VariableStore _get_default_variable_store() { - var store = ops.get_collection(_VARSTORE_KEY); - if (store != null) - return (store as List<_VariableStore>)[0]; - - var store1 = new _VariableStore(); - ops.add_to_collection(_VARSTORE_KEY, store1); - return store1; + var store = ops.get_collection<_VariableStore>(_VARSTORE_KEY).FirstOrDefault(); + if (store == null) + { + store = new _VariableStore(); + ops.add_to_collection(_VARSTORE_KEY, store); + } + return store; } public static VariableScope get_variable_scope() @@ -229,32 +229,19 @@ namespace Tensorflow return get_variable_scope_store().current_scope; } + + // TODO: Misses RefVariable as possible value type? public static _VariableScopeStore get_variable_scope_store() { - _VariableScopeStore ret = null; - var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); + var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault(); + if (scope_store == null) + scope_store = ops.get_collection(_VARSCOPESTORE_KEY).FirstOrDefault(); if (scope_store == null) { - ret = new _VariableScopeStore(); - ops.add_to_collection(_VARSCOPESTORE_KEY, ret); - } - else - { - switch (scope_store) - { - case List values: - ret = values[0]; - break; - case List<_VariableScopeStore> values: - ret = values[0]; - break; - default: - throw new InvalidOperationException("get_variable_scope_store"); - } - + scope_store = new _VariableScopeStore(); + ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); } - - return ret; + return scope_store; } public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true) diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index 0e056949..818b324e 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -41,13 +41,8 @@ namespace Tensorflow { var all = new List(); - var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); - if(collection != null) - all.AddRange(collection as List); - - collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope); - if (collection != null) - all.AddRange(collection as List); + all.AddRange(ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope)); + all.AddRange(ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope)); return all.ToArray(); } @@ -65,9 +60,8 @@ namespace Tensorflow /// A list of `Variable` objects. public static List global_variables(string scope = null) { - var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); + return ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); - return result == null ? new List() : result as List; } /// diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 02417594..df43dc63 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -63,7 +63,7 @@ namespace Tensorflow /// list contains the values in the order under which they were /// collected. /// - public static object get_collection(string key, string scope = null) + public static List get_collection(string key, string scope = null) { return get_default_graph().get_collection(key, scope); } diff --git a/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs b/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs index d3484d5e..0168f22c 100644 --- a/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs @@ -14,7 +14,6 @@ namespace TensorFlowNET.UnitTest.Keras [TestClass] public class EmbeddingTest { - [Ignore] [TestMethod] public void Embedding() { diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index e1a91560..5fb65d66 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -1,4 +1,5 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; using NumSharp; using Tensorflow; using static Tensorflow.Binding; @@ -41,7 +42,7 @@ namespace TensorFlowNET.UnitTest tf_with(tf.variable_scope("bar"), delegate { var v = tf.get_variable("v", new TensorShape(1)); - Assert.AreEqual(v.name, "foo/bar/v:0"); + v.name.Should().Be("foo/bar/v:0"); }); }); }