diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index f8ef8900..512e08c2 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -288,7 +288,7 @@ namespace Tensorflow
return _nodes_by_name.Values.Select(x => x).ToArray();
}
- public object get_collection(string name)
+ public object get_collection(string name, string scope = "")
{
return _collections.ContainsKey(name) ? _collections[name] : null;
}
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
new file mode 100644
index 00000000..544273ba
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class control_flow_ops
+ {
+ public static Operation group(Operation[] inputs)
+ {
+ return null;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs
new file mode 100644
index 00000000..2e4d6423
--- /dev/null
+++ b/src/TensorFlowNET.Core/Python.cs
@@ -0,0 +1,17 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ ///
+ /// Mapping C# functions to Python
+ ///
+ public class Python
+ {
+ protected void print(object obj)
+ {
+ Console.WriteLine(obj.ToString());
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index d7e1c980..faf38994 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -12,7 +12,10 @@ namespace Tensorflow
public bool _trainable;
public Tensor _variable;
public Tensor _snapshot;
- public Operation op;
+
+ private Operation _initializer_op;
+ public Operation initializer => _initializer_op;
+ public Operation op => _initializer_op;
public RefVariable(object initial_value,
bool trainable = true,
@@ -81,7 +84,7 @@ namespace Tensorflow
// have an issue if these other variables aren't initialized first by
// using their initialized_value() method.
- var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;
+ _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;
if (!String.IsNullOrEmpty(caching_device))
{
@@ -92,7 +95,6 @@ namespace Tensorflow
_snapshot = gen_array_ops.identity(_variable, name = "read");
}
- op = _initializer_op;
ops.add_to_collections(collections, this);
}
}
diff --git a/src/TensorFlowNET.Core/Variables/tf.variable.cs b/src/TensorFlowNET.Core/Variables/tf.variable.cs
new file mode 100644
index 00000000..f3f960e2
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/tf.variable.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public static partial class tf
+ {
+ public static Operation global_variables_initializer()
+ {
+ var g = variables.global_variables();
+ return variables.variables_initializer(g as RefVariable[]);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs
index 1e9c426b..596cc5ab 100644
--- a/src/TensorFlowNET.Core/Variables/variables.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variables.py.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
namespace Tensorflow
@@ -14,5 +15,32 @@ namespace Tensorflow
{
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
}
+
+ ///
+ /// Returns global variables.
+ ///
+ ///
+ /// (Optional.) A string. If supplied, the resulting list is filtered
+ /// to include only items whose `name` attribute matches `scope` using
+ /// `re.match`. Items without a `name` attribute are never returned if a
+ /// scope is supplied. The choice of `re.match` means that a `scope` without
+ /// special tokens filters by prefix.
+ ///
+ /// A list of `Variable` objects.
+ public static object global_variables(string scope = "")
+ {
+ return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
+ }
+
+ ///
+ /// Returns an Op that initializes a list of variables.
+ ///
+ /// List of `Variable` objects to initialize.
+ /// Optional name for the returned operation.
+ /// An Op that run the initializers of all the specified variables.
+ public static Operation variables_initializer(RefVariable[] var_list, string name = "init")
+ {
+ return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray());
+ }
}
}
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index feda8bb6..6fbc4699 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -24,9 +24,23 @@ namespace Tensorflow
graph.add_to_collections(names, value);
}
- public static object get_collection(string key)
+ ///
+ /// Wrapper for `Graph.get_collection()` using the default graph.
+ /// contains many standard names for collections.
+ ///
+ ///
+ /// The key for the collection. For example, the `GraphKeys` class
+ ///
+ ///
+ ///
+ /// The list of values in the collection with the given `name`, or
+ /// an empty list if no value has been added to that collection. The
+ /// list contains the values in the order under which they were
+ /// collected.
+ ///
+ public static object get_collection(string key, string scope = "")
{
- return get_default_graph().get_collection(key);
+ return get_default_graph().get_collection(key, scope);
}
public static Graph get_default_graph()
diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs
index 7b8f1ded..f316c3f0 100644
--- a/test/TensorFlowNET.UnitTest/VariableTest.cs
+++ b/test/TensorFlowNET.UnitTest/VariableTest.cs
@@ -7,7 +7,7 @@ using Tensorflow;
namespace TensorFlowNET.UnitTest
{
[TestClass]
- public class VariableTest
+ public class VariableTest : Python
{
[TestMethod]
public void StringVar()
@@ -22,5 +22,28 @@ namespace TensorFlowNET.UnitTest
var x = tf.Variable(3);
var y = tf.Variable(6f);
}
+
+ ///
+ /// https://databricks.com/tensorflow/variables
+ ///
+ [TestMethod]
+ public void Add()
+ {
+ var x = tf.Variable(0, name: "x");
+
+ var model = tf.global_variables_initializer();
+
+ using (var session = tf.Session())
+ {
+ /*session.run(model);
+ for(int i = 0; i < 5; i++)
+ {
+ x = x + 1;
+ var result = session.run(x);
+ print(result);
+ }*/
+ }
+
+ }
}
}