Browse Source

Fixed error in type checking for generic get_collection<T>.

tags/v0.13
Mascha, Philipp Haiping Chen 6 years ago
parent
commit
acf8fbd10d
2 changed files with 14 additions and 11 deletions
  1. +14
    -9
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +0
    -2
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs

+ 14
- 9
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -424,18 +424,23 @@ namespace Tensorflow
return get_collection<object>(name, scope); return get_collection<object>(name, scope);
} }
private IEnumerable<object> findObjects(string name, string scope)
public List<T> get_collection<T>(string name, string scope = null)
{ {
return (from c in _collections where c.name == name && (scope == null || c.scope == scope) select c.item);
return (from c in _collections
where c.name == name &&
(scope == null || c.scope == scope) &&
implementationOf<T>(c.item)
select (T)(c.item)).ToList();
}
private static bool implementationOf<T>(object item)
{
return (item.GetType() == typeof(T) || item.GetType().IsSubclassOf(typeof(T)));
} }
public List<T> get_collection<T>(string name, string scope = null)
{
return (from c in findObjects(name, scope) where c.GetType().IsSubclassOf(typeof(T)) select (T)c).ToList();
}

public List<T> get_collection_ref<T>(string name) public List<T> get_collection_ref<T>(string name)
{ {
return get_collection<T>(name); return get_collection<T>(name);


+ 0
- 2
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -229,8 +229,6 @@ namespace Tensorflow
return get_variable_scope_store().current_scope; return get_variable_scope_store().current_scope;
} }



// TODO: Misses RefVariable as possible value type?
public static _VariableScopeStore get_variable_scope_store() public static _VariableScopeStore get_variable_scope_store()
{ {
var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault(); var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault();


Loading…
Cancel
Save