From b87081cc4ac47dd3ad28f40cd899a228c0fc3552 Mon Sep 17 00:00:00 2001
From: "Mascha, Philipp"
Date: Thu, 14 Nov 2019 14:07:59 +0100
Subject: [PATCH] Made internal collection of graph a list.
---
src/TensorFlowNET.Core/APIs/tf.variable.cs | 4 +-
.../Framework/meta_graph.cs | 30 ++++-----
src/TensorFlowNET.Core/Graphs/Graph.cs | 67 ++++++-------------
src/TensorFlowNET.Core/Summaries/Summary.cs | 7 +-
.../Variables/variable_scope.py.cs | 43 +++++-------
.../Variables/variables.py.cs | 12 +---
src/TensorFlowNET.Core/ops.cs | 2 +-
.../Keras/EmbeddingTest.cs | 1 -
test/TensorFlowNET.UnitTest/VariableTest.cs | 5 +-
9 files changed, 61 insertions(+), 110 deletions(-)
diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs
index 179cedee..1d07747b 100644
--- a/src/TensorFlowNET.Core/APIs/tf.variable.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs
@@ -23,14 +23,14 @@ namespace Tensorflow
{
public VariableV1[] global_variables(string scope = null)
{
- return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope) as List)
+ return (ops.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope))
.ToArray();
}
public Operation global_variables_initializer()
{
var g = variables.global_variables();
- return variables.variables_initializer(g.ToArray());
+ return variables.variables_initializer(g?.ToArray());
}
///
diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs
index 05092581..21684001 100644
--- a/src/TensorFlowNET.Core/Framework/meta_graph.cs
+++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs
@@ -264,26 +264,22 @@ namespace Tensorflow
if (!meta_graph_def.CollectionDef.ContainsKey(key))
meta_graph_def.CollectionDef[key] = new CollectionDef();
var col_def = meta_graph_def.CollectionDef[key];
-
- switch (graph.get_collection(key))
+ col_def.NodeList = new Types.NodeList();
+ col_def.BytesList = new Types.BytesList();
+ foreach (object value in graph.get_collection(key))
{
- case List collection_list:
- col_def.BytesList = new Types.BytesList();
- foreach (var x in collection_list)
- {
+ switch (value)
+ {
+ case RefVariable x:
var proto = x.to_proto(export_scope);
col_def.BytesList.Value.Add(proto.ToByteString());
- }
-
- break;
- case List
private bool _finalized = false;
+
///
- /// Arbitrary collections of objects.
+ /// Arbitrary collections of objects inside the graph.
+ /// TODO: Access might be slow (-> O(n)) depending on size.
///
- private Dictionary _collections = new Dictionary();
+ private readonly ICollection<(string name, string scope, object item)> _collections = new List<(string name, string scope, object item)>();
public bool building_function;
@@ -228,16 +230,14 @@ namespace Tensorflow
throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}.");
}
- public void add_to_collection(string name, T value)
+ public void add_to_collection(string name, object value)
{
_check_not_finalized();
- if (_collections.ContainsKey(name))
- (_collections[name] as List).Add(value);
- else
- _collections[name] = new List { value };
+ _collections.Add((name, null, value));
}
- public void add_to_collections(List names, T value)
+
+ public void add_to_collections(List names, object value)
{
foreach (string name in names)
add_to_collection(name, value);
@@ -278,12 +278,6 @@ namespace Tensorflow
_create_op_helper(op, true);
- /*Console.Write($"create_op: {op_type} '{node_def.Name}'");
- Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");
- Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}");
- Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}");
- Console.WriteLine();*/
-
return op;
}
@@ -422,46 +416,29 @@ namespace Tensorflow
public string[] get_all_collection_keys()
{
- return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray();
+ return (from c in _collections where !c.name.StartsWith("__") select c.name).ToArray();
}
- public object get_collection(string name, string scope = null)
+ public List