| @@ -25,6 +25,7 @@ namespace Tensorflow | |||||
| /// size_t* => ref uint | /// size_t* => ref uint | ||||
| /// void* => IntPtr | /// void* => IntPtr | ||||
| /// string => IntPtr c_api.StringPiece(IntPtr) | /// string => IntPtr c_api.StringPiece(IntPtr) | ||||
| /// unsigned char => byte | |||||
| /// </summary> | /// </summary> | ||||
| public static partial class c_api | public static partial class c_api | ||||
| { | { | ||||
| @@ -0,0 +1,19 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class ValueError : Exception | |||||
| { | |||||
| public ValueError() : base() | |||||
| { | |||||
| } | |||||
| public ValueError(string message) : base(message) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -27,6 +27,11 @@ namespace Tensorflow | |||||
| public string _graph_key; | public string _graph_key; | ||||
| public Status Status { get; } | public Status Status { get; } | ||||
| /// <summary> | |||||
| /// Arbitrary collections of objects. | |||||
| /// </summary> | |||||
| private Dictionary<string, object> _collections = new Dictionary<string, object>(); | |||||
| public Graph() | public Graph() | ||||
| { | { | ||||
| _handle = c_api.TF_NewGraph(); | _handle = c_api.TF_NewGraph(); | ||||
| @@ -86,6 +91,11 @@ namespace Tensorflow | |||||
| throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | ||||
| } | } | ||||
| public void add_to_collection(string name, object value) | |||||
| { | |||||
| _collections[name] = value; | |||||
| } | |||||
| public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | ||||
| TF_DataType[] input_types = null, string name = "", | TF_DataType[] input_types = null, string name = "", | ||||
| Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | ||||
| @@ -221,6 +231,11 @@ namespace Tensorflow | |||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
| } | } | ||||
| public Dictionary<string, object> get_collection(string name) | |||||
| { | |||||
| return _collections; | |||||
| } | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| c_api.TF_DeleteGraph(_handle); | c_api.TF_DeleteGraph(_handle); | ||||
| @@ -49,6 +49,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| var var_list = variables.trainable_variables(); | |||||
| return null; | return null; | ||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,7 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class RefVariable : Variable | |||||
| public class RefVariable : VariableV1 | |||||
| { | { | ||||
| public bool _in_graph_mode = true; | public bool _in_graph_mode = true; | ||||
| public Tensor _initial_value; | public Tensor _initial_value; | ||||
| @@ -0,0 +1,11 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class VariableScope | |||||
| { | |||||
| public bool? use_resource { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public enum VariableSynchronization | |||||
| { | |||||
| AUTO = 0, | |||||
| NONE = 1, | |||||
| ON_WRITE = 2, | |||||
| ON_READ = 3 | |||||
| } | |||||
| } | |||||
| @@ -14,9 +14,9 @@ namespace Tensorflow | |||||
| /// the variable are fixed. The value can be changed using one of the assign methods. | /// the variable are fixed. The value can be changed using one of the assign methods. | ||||
| /// https://tensorflow.org/guide/variables | /// https://tensorflow.org/guide/variables | ||||
| /// </summary> | /// </summary> | ||||
| public class Variable | |||||
| public class VariableV1 | |||||
| { | { | ||||
| public Variable(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) | |||||
| public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) | |||||
| { | { | ||||
| } | } | ||||
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class _VariableScopeStore | |||||
| { | |||||
| public VariableScope current_scope { get; set; } | |||||
| public _VariableScopeStore() | |||||
| { | |||||
| current_scope = new VariableScope(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,74 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class variable_scope | |||||
| { | |||||
| public static string _VARSCOPESTORE_KEY = "__varscope"; | |||||
| public static bool _DEFAULT_USE_RESOURCE = false; | |||||
| public static RefVariable default_variable_creator(object initial_value, string name = "", TF_DataType dtype = TF_DataType.DtInvalid, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO) | |||||
| { | |||||
| var trainable = _get_trainable_value(synchronization); | |||||
| if (!use_resource.HasValue) | |||||
| { | |||||
| use_resource = get_variable_scope().use_resource; | |||||
| } | |||||
| if(!use_resource.HasValue) | |||||
| use_resource = _DEFAULT_USE_RESOURCE; | |||||
| if (use_resource.Value) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else | |||||
| { | |||||
| return new RefVariable(initial_value); | |||||
| } | |||||
| } | |||||
| public static VariableScope get_variable_scope() | |||||
| { | |||||
| return get_variable_scope_store().current_scope; | |||||
| } | |||||
| public static _VariableScopeStore get_variable_scope_store() | |||||
| { | |||||
| var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); | |||||
| if (scope_store == null) | |||||
| { | |||||
| scope_store = new _VariableScopeStore(); | |||||
| ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); | |||||
| } | |||||
| else | |||||
| { | |||||
| // scope_store = scope_store[0]; | |||||
| } | |||||
| return scope_store; | |||||
| } | |||||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) | |||||
| { | |||||
| if(synchronization == VariableSynchronization.ON_READ) | |||||
| { | |||||
| if (trainable.Value) | |||||
| throw new ValueError("Synchronization value can be set to " + | |||||
| "VariableSynchronization.ON_READ only for non-trainable variables. " + | |||||
| "You have specified trainable=True and " + | |||||
| "synchronization=VariableSynchronization.ON_READ."); | |||||
| else | |||||
| trainable = false; | |||||
| } | |||||
| else if (!trainable.HasValue) | |||||
| { | |||||
| trainable = true; | |||||
| } | |||||
| return trainable.Value; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,18 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class variables | |||||
| { | |||||
| /// <summary> | |||||
| /// Returns all variables created with `trainable=True` | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public static object trainable_variables() | |||||
| { | |||||
| return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class ops | |||||
| { | |||||
| /// <summary> | |||||
| /// Standard names to use for graph collections. | |||||
| /// The standard library uses various well-known names to collect and | |||||
| /// retrieve values associated with a graph. For example, the | |||||
| /// `tf.Optimizer` subclasses default to optimizing the variables | |||||
| /// collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is | |||||
| /// specified, but it is also possible to pass an explicit list of | |||||
| /// variables. | |||||
| /// </summary> | |||||
| public static class GraphKey | |||||
| { | |||||
| /// <summary> | |||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | |||||
| /// </summary> | |||||
| public static string TRAINABLE_VARIABLES = "trainable_variables"; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -10,8 +10,19 @@ using System.Linq; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public static class ops | |||||
| public partial class ops | |||||
| { | { | ||||
| public static void add_to_collection(string name, object value) | |||||
| { | |||||
| var graph = tf.get_default_graph(); | |||||
| graph.add_to_collection(name, value); | |||||
| } | |||||
| public static _VariableScopeStore get_collection(string key) | |||||
| { | |||||
| return null;// get_default_graph().get_collection(key); | |||||
| } | |||||
| public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
| { | { | ||||
| return tf.Graph(); | return tf.Graph(); | ||||
| @@ -22,7 +22,7 @@ namespace Tensorflow | |||||
| public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) | public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) | ||||
| { | { | ||||
| return new RefVariable(data, name, dtype); | |||||
| return variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid); | |||||
| } | } | ||||
| public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | ||||
| @@ -43,18 +43,13 @@ namespace TensorFlowNET.Examples | |||||
| var sub = pred - Y; | var sub = pred - Y; | ||||
| var pow = tf.pow(sub, 2); | var pow = tf.pow(sub, 2); | ||||
| var reduce = tf.reduce_sum(pow); | var reduce = tf.reduce_sum(pow); | ||||
| var cost = reduce / (2d * n_samples); | var cost = reduce / (2d * n_samples); | ||||
| // radient descent | // radient descent | ||||
| // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | ||||
| var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | |||||
| var optimizer = tf.train.GradientDescentOptimizer(learning_rate); | |||||
| optimizer.minimize(cost); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||