| @@ -23,14 +23,14 @@ namespace Tensorflow | |||||
| { | { | ||||
| public VariableV1[] global_variables(string scope = null) | public VariableV1[] global_variables(string scope = null) | ||||
| { | { | ||||
| return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>) | |||||
| return (ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope)) | |||||
| .ToArray(); | .ToArray(); | ||||
| } | } | ||||
| public Operation global_variables_initializer() | public Operation global_variables_initializer() | ||||
| { | { | ||||
| var g = variables.global_variables(); | var g = variables.global_variables(); | ||||
| return variables.variables_initializer(g.ToArray()); | |||||
| return variables.variables_initializer(g?.ToArray()); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -264,26 +264,22 @@ namespace Tensorflow | |||||
| if (!meta_graph_def.CollectionDef.ContainsKey(key)) | if (!meta_graph_def.CollectionDef.ContainsKey(key)) | ||||
| meta_graph_def.CollectionDef[key] = new CollectionDef(); | meta_graph_def.CollectionDef[key] = new CollectionDef(); | ||||
| var col_def = meta_graph_def.CollectionDef[key]; | var col_def = meta_graph_def.CollectionDef[key]; | ||||
| switch (graph.get_collection(key)) | |||||
| col_def.NodeList = new Types.NodeList(); | |||||
| col_def.BytesList = new Types.BytesList(); | |||||
| foreach (object value in graph.get_collection(key)) | |||||
| { | { | ||||
| case List<RefVariable> collection_list: | |||||
| col_def.BytesList = new Types.BytesList(); | |||||
| foreach (var x in collection_list) | |||||
| { | |||||
| switch (value) | |||||
| { | |||||
| case RefVariable x: | |||||
| var proto = x.to_proto(export_scope); | var proto = x.to_proto(export_scope); | ||||
| col_def.BytesList.Value.Add(proto.ToByteString()); | col_def.BytesList.Value.Add(proto.ToByteString()); | ||||
| } | |||||
| break; | |||||
| case List<object> collection_list: | |||||
| col_def.NodeList = new Types.NodeList(); | |||||
| foreach (var x in collection_list) | |||||
| if (x is ITensorOrOperation x2) | |||||
| col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope)); | |||||
| break; | |||||
| case List<Operation> collection_list: | |||||
| break; | |||||
| break; | |||||
| case ITensorOrOperation x2: | |||||
| col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope)); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -100,10 +100,12 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| private bool _finalized = false; | private bool _finalized = false; | ||||
| /// <summary> | /// <summary> | ||||
| /// Arbitrary collections of objects. | |||||
| /// Arbitrary collections of objects inside the graph. | |||||
| /// TODO: Access might be slow (-> O(n)) depending on size. | |||||
| /// </summary> | /// </summary> | ||||
| private Dictionary<string, object> _collections = new Dictionary<string, object>(); | |||||
| private readonly ICollection<(string name, string scope, object item)> _collections = new List<(string name, string scope, object item)>(); | |||||
| public bool building_function; | public bool building_function; | ||||
| @@ -228,16 +230,14 @@ namespace Tensorflow | |||||
| throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}."); | throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}."); | ||||
| } | } | ||||
| public void add_to_collection<T>(string name, T value) | |||||
| public void add_to_collection(string name, object value) | |||||
| { | { | ||||
| _check_not_finalized(); | _check_not_finalized(); | ||||
| if (_collections.ContainsKey(name)) | |||||
| (_collections[name] as List<T>).Add(value); | |||||
| else | |||||
| _collections[name] = new List<T> { value }; | |||||
| _collections.Add((name, null, value)); | |||||
| } | } | ||||
| public void add_to_collections<T>(List<string> names, T value) | |||||
| public void add_to_collections(List<string> names, object value) | |||||
| { | { | ||||
| foreach (string name in names) | foreach (string name in names) | ||||
| add_to_collection(name, value); | add_to_collection(name, value); | ||||
| @@ -278,12 +278,6 @@ namespace Tensorflow | |||||
| _create_op_helper(op, true); | _create_op_helper(op, true); | ||||
| /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); | |||||
| Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); | |||||
| Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}"); | |||||
| Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}"); | |||||
| Console.WriteLine();*/ | |||||
| return op; | return op; | ||||
| } | } | ||||
| @@ -422,46 +416,34 @@ namespace Tensorflow | |||||
| public string[] get_all_collection_keys() | public string[] get_all_collection_keys() | ||||
| { | { | ||||
| return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | |||||
| return (from c in _collections where !c.name.StartsWith("__") select c.name).ToArray(); | |||||
| } | } | ||||
| public object get_collection(string name, string scope = null) | |||||
| public List<object> get_collection(string name, string scope = null) | |||||
| { | { | ||||
| return _collections.ContainsKey(name) ? _collections[name] : null; | |||||
| } | |||||
| return get_collection<object>(name, scope); | |||||
| } | |||||
| public List<T> get_collection<T>(string name, string scope = null) | public List<T> get_collection<T>(string name, string scope = null) | ||||
| { | |||||
| List<T> t = default; | |||||
| var collection = _collections.ContainsKey(name) ? _collections[name] : new List<T>(); | |||||
| switch (collection) | |||||
| { | |||||
| case List<VariableV1> list: | |||||
| t = list.Select(x => (T)(object)x).ToList(); | |||||
| break; | |||||
| case List<ResourceVariable> list: | |||||
| t = list.Select(x => (T)(object)x).ToList(); | |||||
| break; | |||||
| case List<RefVariable> list: | |||||
| t = list.Select(x => (T)(object)x).ToList(); | |||||
| break; | |||||
| case List<Tensor> list: | |||||
| t = list.Select(x => (T)(object)x).ToList(); | |||||
| break; | |||||
| case List<Operation> list: | |||||
| t = list.Select(x => (T)(object)x).ToList(); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException($"get_collection<{typeof(T).FullName}>"); | |||||
| } | |||||
| return t; | |||||
| } | |||||
| { | |||||
| return (from c in _collections | |||||
| where c.name == name && | |||||
| (scope == null || c.scope == scope) && | |||||
| implementationOf<T>(c.item) | |||||
| select (T)(c.item)).ToList(); | |||||
| } | |||||
| private static bool implementationOf<T>(object item) | |||||
| { | |||||
| return (item.GetType() == typeof(T) || item.GetType().IsSubclassOf(typeof(T))); | |||||
| } | |||||
| public List<T> get_collection_ref<T>(string name) | public List<T> get_collection_ref<T>(string name) | ||||
| { | { | ||||
| if (!_collections.ContainsKey(name)) | |||||
| _collections[name] = new List<T>(); | |||||
| return _collections[name] as List<T>; | |||||
| return get_collection<T>(name); | |||||
| } | } | ||||
| public void prevent_feeding(Tensor tensor) | public void prevent_feeding(Tensor tensor) | ||||
| @@ -39,11 +39,8 @@ namespace Tensorflow.Summaries | |||||
| public Tensor merge_all(string key = "summaries", string scope= null, string name= null) | public Tensor merge_all(string key = "summaries", string scope= null, string name= null) | ||||
| { | { | ||||
| var summary_ops = ops.get_collection(key, scope: scope); | |||||
| if (summary_ops == null) | |||||
| return null; | |||||
| else | |||||
| return merge((summary_ops as List<ITensorOrOperation>).Select(x => x as Tensor).ToArray(), name: name); | |||||
| var summary_ops = ops.get_collection<ITensorOrOperation>(key, scope: scope); | |||||
| return merge(summary_ops.Select(x => x as Tensor).ToArray(), name: name); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -215,13 +215,13 @@ namespace Tensorflow | |||||
| public static _VariableStore _get_default_variable_store() | public static _VariableStore _get_default_variable_store() | ||||
| { | { | ||||
| var store = ops.get_collection(_VARSTORE_KEY); | |||||
| if (store != null) | |||||
| return (store as List<_VariableStore>)[0]; | |||||
| var store1 = new _VariableStore(); | |||||
| ops.add_to_collection(_VARSTORE_KEY, store1); | |||||
| return store1; | |||||
| var store = ops.get_collection<_VariableStore>(_VARSTORE_KEY).FirstOrDefault(); | |||||
| if (store == null) | |||||
| { | |||||
| store = new _VariableStore(); | |||||
| ops.add_to_collection(_VARSTORE_KEY, store); | |||||
| } | |||||
| return store; | |||||
| } | } | ||||
| public static VariableScope get_variable_scope() | public static VariableScope get_variable_scope() | ||||
| @@ -231,30 +231,15 @@ namespace Tensorflow | |||||
| public static _VariableScopeStore get_variable_scope_store() | public static _VariableScopeStore get_variable_scope_store() | ||||
| { | { | ||||
| _VariableScopeStore ret = null; | |||||
| var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); | |||||
| var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault(); | |||||
| if (scope_store == null) | |||||
| scope_store = ops.get_collection<RefVariable>(_VARSCOPESTORE_KEY).FirstOrDefault(); | |||||
| if (scope_store == null) | if (scope_store == null) | ||||
| { | { | ||||
| ret = new _VariableScopeStore(); | |||||
| ops.add_to_collection(_VARSCOPESTORE_KEY, ret); | |||||
| } | |||||
| else | |||||
| { | |||||
| switch (scope_store) | |||||
| { | |||||
| case List<RefVariable> values: | |||||
| ret = values[0]; | |||||
| break; | |||||
| case List<_VariableScopeStore> values: | |||||
| ret = values[0]; | |||||
| break; | |||||
| default: | |||||
| throw new InvalidOperationException("get_variable_scope_store"); | |||||
| } | |||||
| scope_store = new _VariableScopeStore(); | |||||
| ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); | |||||
| } | } | ||||
| return ret; | |||||
| return scope_store; | |||||
| } | } | ||||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true) | public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true) | ||||
| @@ -41,13 +41,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| var all = new List<VariableV1>(); | var all = new List<VariableV1>(); | ||||
| var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| if(collection != null) | |||||
| all.AddRange(collection as List<VariableV1>); | |||||
| collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope); | |||||
| if (collection != null) | |||||
| all.AddRange(collection as List<VariableV1>); | |||||
| all.AddRange(ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope)); | |||||
| all.AddRange(ops.get_collection<VariableV1>(tf.GraphKeys.SAVEABLE_OBJECTS, scope)); | |||||
| return all.ToArray(); | return all.ToArray(); | ||||
| } | } | ||||
| @@ -65,9 +60,8 @@ namespace Tensorflow | |||||
| /// <returns>A list of `Variable` objects.</returns> | /// <returns>A list of `Variable` objects.</returns> | ||||
| public static List<VariableV1> global_variables(string scope = null) | public static List<VariableV1> global_variables(string scope = null) | ||||
| { | { | ||||
| var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| return ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| return result == null ? new List<VariableV1>() : result as List<VariableV1>; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -63,7 +63,7 @@ namespace Tensorflow | |||||
| /// list contains the values in the order under which they were | /// list contains the values in the order under which they were | ||||
| /// collected. | /// collected. | ||||
| /// </returns> | /// </returns> | ||||
| public static object get_collection(string key, string scope = null) | |||||
| public static List<object> get_collection(string key, string scope = null) | |||||
| { | { | ||||
| return get_default_graph().get_collection(key, scope); | return get_default_graph().get_collection(key, scope); | ||||
| } | } | ||||
| @@ -14,7 +14,6 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
| [TestClass] | [TestClass] | ||||
| public class EmbeddingTest | public class EmbeddingTest | ||||
| { | { | ||||
| [Ignore] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Embedding() | public void Embedding() | ||||
| { | { | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using FluentAssertions; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | using NumSharp; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -41,7 +42,7 @@ namespace TensorFlowNET.UnitTest | |||||
| tf_with(tf.variable_scope("bar"), delegate | tf_with(tf.variable_scope("bar"), delegate | ||||
| { | { | ||||
| var v = tf.get_variable("v", new TensorShape(1)); | var v = tf.get_variable("v", new TensorShape(1)); | ||||
| Assert.AreEqual(v.name, "foo/bar/v:0"); | |||||
| v.name.Should().Be("foo/bar/v:0"); | |||||
| }); | }); | ||||
| }); | }); | ||||
| } | } | ||||