Browse Source

Made internal collection of graph a list.

tags/v0.13
Mascha, Philipp Haiping Chen 6 years ago
parent
commit
b87081cc4a
9 changed files with 61 additions and 110 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.variable.cs
  2. +13
    -17
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  3. +22
    -45
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +2
    -5
      src/TensorFlowNET.Core/Summaries/Summary.cs
  5. +15
    -28
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  6. +3
    -9
      src/TensorFlowNET.Core/Variables/variables.py.cs
  7. +1
    -1
      src/TensorFlowNET.Core/ops.cs
  8. +0
    -1
      test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs
  9. +3
    -2
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.variable.cs View File

@@ -23,14 +23,14 @@ namespace Tensorflow
{
public VariableV1[] global_variables(string scope = null)
{
return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List<VariableV1>)
return (ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope))
.ToArray();
}

public Operation global_variables_initializer()
{
var g = variables.global_variables();
return variables.variables_initializer(g.ToArray());
return variables.variables_initializer(g?.ToArray());
}

/// <summary>


+ 13
- 17
src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -264,26 +264,22 @@ namespace Tensorflow
if (!meta_graph_def.CollectionDef.ContainsKey(key))
meta_graph_def.CollectionDef[key] = new CollectionDef();
var col_def = meta_graph_def.CollectionDef[key];

switch (graph.get_collection(key))
col_def.NodeList = new Types.NodeList();
col_def.BytesList = new Types.BytesList();
foreach (object value in graph.get_collection(key))
{
case List<RefVariable> collection_list:
col_def.BytesList = new Types.BytesList();
foreach (var x in collection_list)
{
switch (value)
{
case RefVariable x:
var proto = x.to_proto(export_scope);
col_def.BytesList.Value.Add(proto.ToByteString());
}
break;
case List<object> collection_list:
col_def.NodeList = new Types.NodeList();
foreach (var x in collection_list)
if (x is ITensorOrOperation x2)
col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope));
break;
case List<Operation> collection_list:
break;
break;
case ITensorOrOperation x2:
col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope));
break;
default:
break;
}
}
}



+ 22
- 45
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -100,10 +100,12 @@ namespace Tensorflow
/// </summary>
private bool _finalized = false;


/// <summary>
/// Arbitrary collections of objects.
/// Arbitrary collections of objects inside the graph.
/// TODO: Access might be slow (-> O(n)) depending on size.
/// </summary>
private Dictionary<string, object> _collections = new Dictionary<string, object>();
private readonly ICollection<(string name, string scope, object item)> _collections = new List<(string name, string scope, object item)>();

public bool building_function;
@@ -228,16 +230,14 @@ namespace Tensorflow
throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}.");
}

public void add_to_collection<T>(string name, T value)
public void add_to_collection(string name, object value)
{
_check_not_finalized();
if (_collections.ContainsKey(name))
(_collections[name] as List<T>).Add(value);
else
_collections[name] = new List<T> { value };
_collections.Add((name, null, value));
}

public void add_to_collections<T>(List<string> names, T value)

public void add_to_collections(List<string> names, object value)
{
foreach (string name in names)
add_to_collection(name, value);
@@ -278,12 +278,6 @@ namespace Tensorflow

_create_op_helper(op, true);

/*Console.Write($"create_op: {op_type} '{node_def.Name}'");
Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
Console.WriteLine();*/

return op;
}

@@ -422,46 +416,29 @@ namespace Tensorflow

public string[] get_all_collection_keys()
{
return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
return (from c in _collections where !c.name.StartsWith("__") select c.name).ToArray();
}

public object get_collection(string name, string scope = null)
public List<object> get_collection(string name, string scope = null)
{
return _collections.ContainsKey(name) ? _collections[name] : null;
}

return get_collection<object>(name, scope);
}
private IEnumerable<object> findObjects(string name, string scope)
{
return (from c in _collections where c.name == name && (scope == null || c.scope == scope) select c.item);
}
public List<T> get_collection<T>(string name, string scope = null)
{
List<T> t = default;
var collection = _collections.ContainsKey(name) ? _collections[name] : new List<T>();
switch (collection)
{
case List<VariableV1> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<ResourceVariable> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<RefVariable> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<Tensor> list:
t = list.Select(x => (T)(object)x).ToList();
break;
case List<Operation> list:
t = list.Select(x => (T)(object)x).ToList();
break;
default:
throw new NotImplementedException($"get_collection<{typeof(T).FullName}>");
}
return t;
return (from c in findObjects(name, scope) where c.GetType().IsSubclassOf(typeof(T)) select (T)c).ToList();
}

public List<T> get_collection_ref<T>(string name)
{
if (!_collections.ContainsKey(name))
_collections[name] = new List<T>();
return _collections[name] as List<T>;
return get_collection<T>(name);
}

public void prevent_feeding(Tensor tensor)


+ 2
- 5
src/TensorFlowNET.Core/Summaries/Summary.cs View File

@@ -39,11 +39,8 @@ namespace Tensorflow.Summaries

public Tensor merge_all(string key = "summaries", string scope= null, string name= null)
{
var summary_ops = ops.get_collection(key, scope: scope);
if (summary_ops == null)
return null;
else
return merge((summary_ops as List<ITensorOrOperation>).Select(x => x as Tensor).ToArray(), name: name);
var summary_ops = ops.get_collection<ITensorOrOperation>(key, scope: scope);
return merge(summary_ops.Select(x => x as Tensor).ToArray(), name: name);
}

/// <summary>


+ 15
- 28
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -215,13 +215,13 @@ namespace Tensorflow

public static _VariableStore _get_default_variable_store()
{
var store = ops.get_collection(_VARSTORE_KEY);
if (store != null)
return (store as List<_VariableStore>)[0];
var store1 = new _VariableStore();
ops.add_to_collection(_VARSTORE_KEY, store1);
return store1;
var store = ops.get_collection<_VariableStore>(_VARSTORE_KEY).FirstOrDefault();
if (store == null)
{
store = new _VariableStore();
ops.add_to_collection(_VARSTORE_KEY, store);
}
return store;
}

public static VariableScope get_variable_scope()
@@ -229,32 +229,19 @@ namespace Tensorflow
return get_variable_scope_store().current_scope;
}


// TODO: Misses RefVariable as possible value type?
public static _VariableScopeStore get_variable_scope_store()
{
_VariableScopeStore ret = null;
var scope_store = ops.get_collection(_VARSCOPESTORE_KEY);
var scope_store = ops.get_collection<_VariableScopeStore>(_VARSCOPESTORE_KEY).FirstOrDefault();
if (scope_store == null)
scope_store = ops.get_collection<RefVariable>(_VARSCOPESTORE_KEY).FirstOrDefault();
if (scope_store == null)
{
ret = new _VariableScopeStore();
ops.add_to_collection(_VARSCOPESTORE_KEY, ret);
}
else
{
switch (scope_store)
{
case List<RefVariable> values:
ret = values[0];
break;
case List<_VariableScopeStore> values:
ret = values[0];
break;
default:
throw new InvalidOperationException("get_variable_scope_store");
}
scope_store = new _VariableScopeStore();
ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store);
}

return ret;
return scope_store;
}

public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true)


+ 3
- 9
src/TensorFlowNET.Core/Variables/variables.py.cs View File

@@ -41,13 +41,8 @@ namespace Tensorflow
{
var all = new List<VariableV1>();

var collection = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope);
if(collection != null)
all.AddRange(collection as List<VariableV1>);

collection = ops.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS, scope);
if (collection != null)
all.AddRange(collection as List<VariableV1>);
all.AddRange(ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope));
all.AddRange(ops.get_collection<VariableV1>(tf.GraphKeys.SAVEABLE_OBJECTS, scope));

return all.ToArray();
}
@@ -65,9 +60,8 @@ namespace Tensorflow
/// <returns>A list of `Variable` objects.</returns>
public static List<VariableV1> global_variables(string scope = null)
{
var result = ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope);
return ops.get_collection<VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES, scope);

return result == null ? new List<VariableV1>() : result as List<VariableV1>;
}

/// <summary>


+ 1
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -63,7 +63,7 @@ namespace Tensorflow
/// list contains the values in the order under which they were
/// collected.
/// </returns>
public static object get_collection(string key, string scope = null)
public static List<object> get_collection(string key, string scope = null)
{
return get_default_graph().get_collection(key, scope);
}


+ 0
- 1
test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs View File

@@ -14,7 +14,6 @@ namespace TensorFlowNET.UnitTest.Keras
[TestClass]
public class EmbeddingTest
{
[Ignore]
[TestMethod]
public void Embedding()
{


+ 3
- 2
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;
@@ -41,7 +42,7 @@ namespace TensorFlowNET.UnitTest
tf_with(tf.variable_scope("bar"), delegate
{
var v = tf.get_variable("v", new TensorShape(1));
Assert.AreEqual(v.name, "foo/bar/v:0");
v.name.Should().Be("foo/bar/v:0");
});
});
}


Loading…
Cancel
Save