Browse Source

Add VariableTest.Add test in UnitTest

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
ca1fa35b8e
8 changed files with 120 additions and 7 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +14
    -0
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  3. +17
    -0
      src/TensorFlowNET.Core/Python.cs
  4. +5
    -3
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  5. +15
    -0
      src/TensorFlowNET.Core/Variables/tf.variable.cs
  6. +28
    -0
      src/TensorFlowNET.Core/Variables/variables.py.cs
  7. +16
    -2
      src/TensorFlowNET.Core/ops.py.cs
  8. +24
    -1
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -288,7 +288,7 @@ namespace Tensorflow
return _nodes_by_name.Values.Select(x => x).ToArray(); return _nodes_by_name.Values.Select(x => x).ToArray();
} }


public object get_collection(string name)
public object get_collection(string name, string scope = "")
{ {
return _collections.ContainsKey(name) ? _collections[name] : null; return _collections.ContainsKey(name) ? _collections[name] : null;
} }


+ 14
- 0
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class control_flow_ops
{
public static Operation group(Operation[] inputs)
{
return null;
}
}
}

+ 17
- 0
src/TensorFlowNET.Core/Python.cs View File

@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Mapping C# functions to Python
/// </summary>
public class Python
{
protected void print(object obj)
{
Console.WriteLine(obj.ToString());
}
}
}

+ 5
- 3
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -12,7 +12,10 @@ namespace Tensorflow
public bool _trainable; public bool _trainable;
public Tensor _variable; public Tensor _variable;
public Tensor _snapshot; public Tensor _snapshot;
public Operation op;

private Operation _initializer_op;
public Operation initializer => _initializer_op;
public Operation op => _initializer_op;


public RefVariable(object initial_value, public RefVariable(object initial_value,
bool trainable = true, bool trainable = true,
@@ -81,7 +84,7 @@ namespace Tensorflow
// have an issue if these other variables aren't initialized first by // have an issue if these other variables aren't initialized first by
// using their initialized_value() method. // using their initialized_value() method.


var _initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;
_initializer_op = gen_state_ops.assign(_variable, _initial_value, validate_shape).op;


if (!String.IsNullOrEmpty(caching_device)) if (!String.IsNullOrEmpty(caching_device))
{ {
@@ -92,7 +95,6 @@ namespace Tensorflow
_snapshot = gen_array_ops.identity(_variable, name = "read"); _snapshot = gen_array_ops.identity(_variable, name = "read");
} }


op = _initializer_op;
ops.add_to_collections(collections, this); ops.add_to_collections(collections, this);
} }
} }


+ 15
- 0
src/TensorFlowNET.Core/Variables/tf.variable.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static Operation global_variables_initializer()
{
var g = variables.global_variables();
return variables.variables_initializer(g as RefVariable[]);
}
}
}

+ 28
- 0
src/TensorFlowNET.Core/Variables/variables.py.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;


namespace Tensorflow namespace Tensorflow
@@ -14,5 +15,32 @@ namespace Tensorflow
{ {
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES); return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES);
} }

/// <summary>
/// Returns global variables.
/// </summary>
/// <param name="scope">
/// (Optional.) A string. If supplied, the resulting list is filtered
/// to include only items whose `name` attribute matches `scope` using
/// `re.match`. Items without a `name` attribute are never returned if a
/// scope is supplied. The choice of `re.match` means that a `scope` without
/// special tokens filters by prefix.
/// </param>
/// <returns>A list of `Variable` objects.</returns>
public static object global_variables(string scope = "")
{
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope);
}

/// <summary>
/// Returns an Op that initializes a list of variables.
/// </summary>
/// <param name="var_list">List of `Variable` objects to initialize.</param>
/// <param name="name">Optional name for the returned operation.</param>
/// <returns>An Op that run the initializers of all the specified variables.</returns>
public static Operation variables_initializer(RefVariable[] var_list, string name = "init")
{
return control_flow_ops.group(var_list.Select(x => x.initializer).ToArray());
}
} }
} }

+ 16
- 2
src/TensorFlowNET.Core/ops.py.cs View File

@@ -24,9 +24,23 @@ namespace Tensorflow
graph.add_to_collections(names, value); graph.add_to_collections(names, value);
} }


public static object get_collection(string key)
/// <summary>
/// Wrapper for `Graph.get_collection()` using the default graph.
/// contains many standard names for collections.
/// </summary>
/// <param name="key">
/// The key for the collection. For example, the `GraphKeys` class
/// </param>
/// <param name="scope"></param>
/// <returns>
/// The list of values in the collection with the given `name`, or
/// an empty list if no value has been added to that collection. The
/// list contains the values in the order under which they were
/// collected.
/// </returns>
public static object get_collection(string key, string scope = "")
{ {
return get_default_graph().get_collection(key);
return get_default_graph().get_collection(key, scope);
} }


public static Graph get_default_graph() public static Graph get_default_graph()


+ 24
- 1
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -7,7 +7,7 @@ using Tensorflow;
namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
[TestClass] [TestClass]
public class VariableTest
public class VariableTest : Python
{ {
[TestMethod] [TestMethod]
public void StringVar() public void StringVar()
@@ -22,5 +22,28 @@ namespace TensorFlowNET.UnitTest
var x = tf.Variable(3); var x = tf.Variable(3);
var y = tf.Variable(6f); var y = tf.Variable(6f);
} }

/// <summary>
/// https://databricks.com/tensorflow/variables
/// </summary>
[TestMethod]
public void Add()
{
var x = tf.Variable(0, name: "x");

var model = tf.global_variables_initializer();

using (var session = tf.Session())
{
/*session.run(model);
for(int i = 0; i < 5; i++)
{
x = x + 1;
var result = session.run(x);
print(result);
}*/
}

}
} }
} }

Loading…
Cancel
Save