From ca1fa35b8ef697ad5c64e041155dee481ae29c8a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 23 Jan 2019 23:03:32 -0600 Subject: [PATCH] Add VariableTest.Add test in UnitTest --- src/TensorFlowNET.Core/Graphs/Graph.cs | 2 +- .../Operations/control_flow_ops.py.cs | 14 ++++++++++ src/TensorFlowNET.Core/Python.cs | 17 +++++++++++ .../Variables/RefVariable.cs | 8 ++++-- .../Variables/tf.variable.cs | 15 ++++++++++ .../Variables/variables.py.cs | 28 +++++++++++++++++++ src/TensorFlowNET.Core/ops.py.cs | 18 ++++++++++-- test/TensorFlowNET.UnitTest/VariableTest.cs | 25 ++++++++++++++++- 8 files changed, 120 insertions(+), 7 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs create mode 100644 src/TensorFlowNET.Core/Python.cs create mode 100644 src/TensorFlowNET.Core/Variables/tf.variable.cs diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index f8ef8900..512e08c2 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -288,7 +288,7 @@ namespace Tensorflow return _nodes_by_name.Values.Select(x => x).ToArray(); } - public object get_collection(string name) + public object get_collection(string name, string scope = "") { return _collections.ContainsKey(name) ? _collections[name] : null; } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs new file mode 100644 index 00000000..544273ba --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class control_flow_ops + { + public static Operation group(Operation[] inputs) + { + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs new file mode 100644 index 00000000..2e4d6423 --- /dev/null +++ b/src/TensorFlowNET.Core/Python.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + /// + /// Mapping C# functions to Python + /// + public class Python + { + protected void print(object obj) + { + Console.WriteLine(obj.ToString()); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index d7e1c980..faf38994 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -12,7 +12,10 @@ namespace Tensorflow public bool _trainable; public Tensor _variable; public Tensor _snapshot; - public Operation op; + + private Operation _initializer_op; + public Operation initializer => _initializer_op; + public Operation op => _initializer_op; public RefVariable(object initial_value, bool trainable = true, @@ -81,7 +84,7 @@ namespace Tensorflow // have an issue if these other variables aren't initialized first by // using their initialized_value() method. - var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; + _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; if (!String.IsNullOrEmpty(caching_device)) { @@ -92,7 +95,6 @@ namespace Tensorflow _snapshot = gen_array_ops.identity(_variable, name = "read"); } - op = _initializer_op; ops.add_to_collections(collections, this); } } diff --git a/src/TensorFlowNET.Core/Variables/tf.variable.cs b/src/TensorFlowNET.Core/Variables/tf.variable.cs new file mode 100644 index 00000000..f3f960e2 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/tf.variable.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static Operation global_variables_initializer() + { + var g = variables.global_variables(); + return variables.variables_initializer(g as RefVariable[]); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index 1e9c426b..596cc5ab 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow @@ -14,5 +15,32 @@ namespace Tensorflow { return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); } + + /// + /// Returns global variables. + /// + /// + /// (Optional.) A string. If supplied, the resulting list is filtered + /// to include only items whose `name` attribute matches `scope` using + /// `re.match`. Items without a `name` attribute are never returned if a + /// scope is supplied. The choice of `re.match` means that a `scope` without + /// special tokens filters by prefix. + /// + /// A list of `Variable` objects. + public static object global_variables(string scope = "") + { + return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); + } + + /// + /// Returns an Op that initializes a list of variables. + /// + /// List of `Variable` objects to initialize. + /// Optional name for the returned operation. + /// An Op that run the initializers of all the specified variables. + public static Operation variables_initializer(RefVariable[] var_list, string name = "init") + { + return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray()); + } } } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index feda8bb6..6fbc4699 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -24,9 +24,23 @@ namespace Tensorflow graph.add_to_collections(names, value); } - public static object get_collection(string key) + /// + /// Wrapper for `Graph.get_collection()` using the default graph. + /// contains many standard names for collections. + /// + /// + /// The key for the collection. For example, the `GraphKeys` class + /// + /// + /// + /// The list of values in the collection with the given `name`, or + /// an empty list if no value has been added to that collection. The + /// list contains the values in the order under which they were + /// collected. + /// + public static object get_collection(string key, string scope = "") { - return get_default_graph().get_collection(key); + return get_default_graph().get_collection(key, scope); } public static Graph get_default_graph() diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 7b8f1ded..f316c3f0 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -7,7 +7,7 @@ using Tensorflow; namespace TensorFlowNET.UnitTest { [TestClass] - public class VariableTest + public class VariableTest : Python { [TestMethod] public void StringVar() @@ -22,5 +22,28 @@ namespace TensorFlowNET.UnitTest var x = tf.Variable(3); var y = tf.Variable(6f); } + + /// + /// https://databricks.com/tensorflow/variables + /// + [TestMethod] + public void Add() + { + var x = tf.Variable(0, name: "x"); + + var model = tf.global_variables_initializer(); + + using (var session = tf.Session()) + { + /*session.run(model); + for(int i = 0; i < 5; i++) + { + x = x + 1; + var result = session.run(x); + print(result); + }*/ + } + + } } }