| @@ -288,7 +288,7 @@ namespace Tensorflow | |||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
| } | } | ||||
| public object get_collection(string name) | |||||
| public object get_collection(string name, string scope = "") | |||||
| { | { | ||||
| return _collections.ContainsKey(name) ? _collections[name] : null; | return _collections.ContainsKey(name) ? _collections[name] : null; | ||||
| } | } | ||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class control_flow_ops | |||||
| { | |||||
| public static Operation group(Operation[] inputs) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,17 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Mapping C# functions to Python | |||||
| /// </summary> | |||||
| public class Python | |||||
| { | |||||
| protected void print(object obj) | |||||
| { | |||||
| Console.WriteLine(obj.ToString()); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -12,7 +12,10 @@ namespace Tensorflow | |||||
| public bool _trainable; | public bool _trainable; | ||||
| public Tensor _variable; | public Tensor _variable; | ||||
| public Tensor _snapshot; | public Tensor _snapshot; | ||||
| public Operation op; | |||||
| private Operation _initializer_op; | |||||
| public Operation initializer => _initializer_op; | |||||
| public Operation op => _initializer_op; | |||||
| public RefVariable(object initial_value, | public RefVariable(object initial_value, | ||||
| bool trainable = true, | bool trainable = true, | ||||
| @@ -81,7 +84,7 @@ namespace Tensorflow | |||||
| // have an issue if these other variables aren't initialized first by | // have an issue if these other variables aren't initialized first by | ||||
| // using their initialized_value() method. | // using their initialized_value() method. | ||||
| var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; | |||||
| _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op; | |||||
| if (!String.IsNullOrEmpty(caching_device)) | if (!String.IsNullOrEmpty(caching_device)) | ||||
| { | { | ||||
| @@ -92,7 +95,6 @@ namespace Tensorflow | |||||
| _snapshot = gen_array_ops.identity(_variable, name = "read"); | _snapshot = gen_array_ops.identity(_variable, name = "read"); | ||||
| } | } | ||||
| op = _initializer_op; | |||||
| ops.add_to_collections(collections, this); | ops.add_to_collections(collections, this); | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static partial class tf | |||||
| { | |||||
| public static Operation global_variables_initializer() | |||||
| { | |||||
| var g = variables.global_variables(); | |||||
| return variables.variables_initializer(g as RefVariable[]); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -14,5 +15,32 @@ namespace Tensorflow | |||||
| { | { | ||||
| return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Returns global variables. | |||||
| /// </summary> | |||||
| /// <param name="scope"> | |||||
| /// (Optional.) A string. If supplied, the resulting list is filtered | |||||
| /// to include only items whose `name` attribute matches `scope` using | |||||
| /// `re.match`. Items without a `name` attribute are never returned if a | |||||
| /// scope is supplied. The choice of `re.match` means that a `scope` without | |||||
| /// special tokens filters by prefix. | |||||
| /// </param> | |||||
| /// <returns>A list of `Variable` objects.</returns> | |||||
| public static object global_variables(string scope = "") | |||||
| { | |||||
| return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope); | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns an Op that initializes a list of variables. | |||||
| /// </summary> | |||||
| /// <param name="var_list">List of `Variable` objects to initialize.</param> | |||||
| /// <param name="name">Optional name for the returned operation.</param> | |||||
| /// <returns>An Op that run the initializers of all the specified variables.</returns> | |||||
| public static Operation variables_initializer(RefVariable[] var_list, string name = "init") | |||||
| { | |||||
| return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray()); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -24,9 +24,23 @@ namespace Tensorflow | |||||
| graph.add_to_collections(names, value); | graph.add_to_collections(names, value); | ||||
| } | } | ||||
| public static object get_collection(string key) | |||||
| /// <summary> | |||||
| /// Wrapper for `Graph.get_collection()` using the default graph. | |||||
| /// contains many standard names for collections. | |||||
| /// </summary> | |||||
| /// <param name="key"> | |||||
| /// The key for the collection. For example, the `GraphKeys` class | |||||
| /// </param> | |||||
| /// <param name="scope"></param> | |||||
| /// <returns> | |||||
| /// The list of values in the collection with the given `name`, or | |||||
| /// an empty list if no value has been added to that collection. The | |||||
| /// list contains the values in the order under which they were | |||||
| /// collected. | |||||
| /// </returns> | |||||
| public static object get_collection(string key, string scope = "") | |||||
| { | { | ||||
| return get_default_graph().get_collection(key); | |||||
| return get_default_graph().get_collection(key, scope); | |||||
| } | } | ||||
| public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
| @@ -7,7 +7,7 @@ using Tensorflow; | |||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class VariableTest | |||||
| public class VariableTest : Python | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void StringVar() | public void StringVar() | ||||
| @@ -22,5 +22,28 @@ namespace TensorFlowNET.UnitTest | |||||
| var x = tf.Variable(3); | var x = tf.Variable(3); | ||||
| var y = tf.Variable(6f); | var y = tf.Variable(6f); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// https://databricks.com/tensorflow/variables | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void Add() | |||||
| { | |||||
| var x = tf.Variable(0, name: "x"); | |||||
| var model = tf.global_variables_initializer(); | |||||
| using (var session = tf.Session()) | |||||
| { | |||||
| /*session.run(model); | |||||
| for(int i = 0; i < 5; i++) | |||||
| { | |||||
| x = x + 1; | |||||
| var result = session.run(x); | |||||
| print(result); | |||||
| }*/ | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||