diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 29514cd0..767e23f7 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -424,18 +424,23 @@ namespace Tensorflow return get_collection(name, scope); } - private IEnumerable findObjects(string name, string scope) + + public List get_collection(string name, string scope = null) { - return (from c in _collections where c.name == name && (scope == null || c.scope == scope) select c.item); + + return (from c in _collections + where c.name == name && + (scope == null || c.scope == scope) && + implementationOf(c.item) + select (T)(c.item)).ToList(); + + } + + private static bool implementationOf(object item) + { + return (item.GetType() == typeof(T) || item.GetType().IsSubclassOf(typeof(T))); } - public List get_collection(string name, string scope = null) - { - - return (from c in findObjects(name, scope) where c.GetType().IsSubclassOf(typeof(T)) select (T)c).ToList(); - - } - public List get_collection_ref(string name) { return get_collection(name); diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index b43663c0..fedb1a27 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -229,8 +229,6 @@ namespace Tensorflow return get_variable_scope_store().current_scope; } - - // TODO: Misses RefVariable as possible value type? public static _VariableScopeStore get_variable_scope_store() { var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault();