| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public interface IInitializer | |||
| { | |||
| Tensor call(TensorShape shape, TF_DataType dtype); | |||
| object get_config(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,34 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public static partial class tf | |||
| { | |||
| public static IInitializer zeros_initializer => new Zeros(); | |||
| public class Zeros : IInitializer | |||
| { | |||
| private TF_DataType dtype; | |||
| public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
| { | |||
| this.dtype = dtype; | |||
| } | |||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = this.dtype; | |||
| return array_ops.zeros(shape, dtype); | |||
| } | |||
| public object get_config() | |||
| { | |||
| return new { dtype = dtype.name() }; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -71,6 +71,11 @@ namespace Tensorflow | |||
| type; | |||
| } | |||
| public static int name(this TF_DataType type) | |||
| { | |||
| return (int)type; | |||
| } | |||
| public static DataType as_base_dtype(this DataType type) | |||
| { | |||
| return (int)type > 100 ? | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public enum VariableAggregation | |||
| { | |||
| NONE = 0, | |||
| SUM = 1, | |||
| MEAN = 2, | |||
| ONLY_FIRST_REPLICA = 3 // ONLY_FIRST_TOWER | |||
| } | |||
| } | |||
| @@ -6,6 +6,35 @@ namespace Tensorflow | |||
| { | |||
| public class VariableScope | |||
| { | |||
| public bool? use_resource { get; set; } | |||
| public bool use_resource { get; set; } | |||
| private _ReuseMode _reuse { get; set; } | |||
| private object _regularizer; | |||
| private TF_DataType _dtype; | |||
| public string name { get; set; } | |||
| public VariableScope() | |||
| { | |||
| _reuse = _ReuseMode.AUTO_REUSE; | |||
| } | |||
| public RefVariable get_variable(_VariableStore var_store, | |||
| string name, | |||
| TensorShape shape = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
| VariableAggregation aggregation= VariableAggregation.NONE) | |||
| { | |||
| string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; | |||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(""), scope => | |||
| { | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = _dtype; | |||
| return var_store.get_variable(full_name); | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,16 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Mode for variable access within a variable scope. | |||
| /// </summary> | |||
| public enum _ReuseMode | |||
| { | |||
| // Indicates that variables are to be fetched if they already exist or | |||
| // otherwise created. | |||
| AUTO_REUSE = 1 | |||
| } | |||
| } | |||
| @@ -0,0 +1,80 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Variable store that carries a number of named Variables. | |||
| /// </summary> | |||
| public class _VariableStore | |||
| { | |||
| private Dictionary<string, object> _vars; | |||
| private Dictionary<string, object> _partitioned_vars; | |||
| private bool _store_eager_variables; | |||
| public _VariableStore() | |||
| { | |||
| _vars = new Dictionary<string, object>(); | |||
| _partitioned_vars = new Dictionary<string, object>(); | |||
| _store_eager_variables = false; | |||
| } | |||
| public RefVariable get_variable(string name, | |||
| TensorShape shape = null, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| IInitializer initializer = null, | |||
| bool trainable = false, | |||
| bool validate_shape = true, | |||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||
| { | |||
| dtype = dtype.as_base_dtype(); | |||
| trainable = variable_scope._get_trainable_value(synchronization, trainable); | |||
| return _true_getter(name, | |||
| shape: shape, | |||
| dtype: dtype, | |||
| initializer: initializer, | |||
| trainable: trainable, | |||
| validate_shape: validate_shape, | |||
| synchronization: synchronization, | |||
| aggregation: aggregation); | |||
| } | |||
| private RefVariable _true_getter(string name, | |||
| TensorShape shape = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| IInitializer initializer = null, | |||
| bool trainable = false, | |||
| bool validate_shape = true, | |||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||
| { | |||
| return _get_single_variable(name: name); | |||
| } | |||
| private RefVariable _get_single_variable(string name, | |||
| TensorShape shape = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| IInitializer initializer = null, | |||
| bool reuse = false, | |||
| bool trainable = false, | |||
| bool validate_shape = false, | |||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||
| { | |||
| if (_vars.ContainsKey(name)) | |||
| { | |||
| if (!reuse) | |||
| { | |||
| var var = _vars[name]; | |||
| } | |||
| throw new NotImplementedException("_get_single_variable"); | |||
| } | |||
| throw new NotImplementedException("_get_single_variable"); | |||
| } | |||
| } | |||
| } | |||
| @@ -11,5 +11,15 @@ namespace Tensorflow | |||
| var g = variables.global_variables(); | |||
| return variables.variables_initializer(g.ToArray()); | |||
| } | |||
| public static RefVariable get_variable(string name, | |||
| TensorShape shape = null, | |||
| IInitializer initializer = null, | |||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||
| { | |||
| var store = variable_scope._get_default_variable_store(); | |||
| return variable_scope.get_variable_scope().get_variable(store, name, shape: shape); | |||
| } | |||
| } | |||
| } | |||
| @@ -6,6 +6,7 @@ namespace Tensorflow | |||
| { | |||
| public class variable_scope | |||
| { | |||
| public static string _VARSTORE_KEY = "__variable_store"; | |||
| public static string _VARSCOPESTORE_KEY = "__varscope"; | |||
| public static bool _DEFAULT_USE_RESOURCE = false; | |||
| @@ -32,6 +33,17 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| public static _VariableStore _get_default_variable_store() | |||
| { | |||
| var store = ops.get_collection(_VARSTORE_KEY); | |||
| if (store != null) | |||
| return (store as List<_VariableStore>)[0]; | |||
| var store1 = new _VariableStore(); | |||
| ops.add_to_collection(_VARSTORE_KEY, store1); | |||
| return store1; | |||
| } | |||
| public static VariableScope get_variable_scope() | |||
| { | |||
| return get_variable_scope_store().current_scope; | |||
| @@ -65,24 +77,18 @@ namespace Tensorflow | |||
| return ret; | |||
| } | |||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) | |||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true) | |||
| { | |||
| if(synchronization == VariableSynchronization.ON_READ) | |||
| if (synchronization == VariableSynchronization.ON_READ) | |||
| { | |||
| if (trainable.Value) | |||
| if (trainable) | |||
| throw new ValueError("Synchronization value can be set to " + | |||
| "VariableSynchronization.ON_READ only for non-trainable variables. " + | |||
| "You have specified trainable=True and " + | |||
| "synchronization=VariableSynchronization.ON_READ."); | |||
| else | |||
| trainable = false; | |||
| } | |||
| else if (!trainable.HasValue) | |||
| { | |||
| trainable = true; | |||
| } | |||
| return trainable.Value; | |||
| return trainable; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,21 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class TrainSaverTest | |||
| { | |||
| [TestMethod] | |||
| public void Save() | |||
| { | |||
| var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||
| var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| | |||
| import tensorflow as tf | |||
| # Create some variables. | |||
| v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) | |||
| v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) | |||
| inc_v1 = v1.assign(v1+1) | |||
| dec_v2 = v2.assign(v2-1) | |||
| # Add an op to initialize the variables. | |||
| init_op = tf.global_variables_initializer() | |||
| # Add ops to save and restore all the variables. | |||
| saver = tf.train.Saver() | |||
| # Later, launch the model, initialize the variables, do some work, and save the | |||
| # variables to disk. | |||
| with tf.Session() as sess: | |||
| sess.run(init_op) | |||
| # Do some work with the model. | |||
| inc_v1.op.run() | |||
| dec_v2.op.run() | |||
| # Save the variables to disk. | |||
| save_path = saver.save(sess, "/tmp/model.ckpt") | |||
| print("Model saved in path: %s" % save_path) | |||