| @@ -0,0 +1,12 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public interface IInitializer | |||||
| { | |||||
| Tensor call(TensorShape shape, TF_DataType dtype); | |||||
| object get_config(); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,34 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static partial class tf | |||||
| { | |||||
| public static IInitializer zeros_initializer => new Zeros(); | |||||
| public class Zeros : IInitializer | |||||
| { | |||||
| private TF_DataType dtype; | |||||
| public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| { | |||||
| this.dtype = dtype; | |||||
| } | |||||
| public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | |||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| dtype = this.dtype; | |||||
| return array_ops.zeros(shape, dtype); | |||||
| } | |||||
| public object get_config() | |||||
| { | |||||
| return new { dtype = dtype.name() }; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -71,6 +71,11 @@ namespace Tensorflow | |||||
| type; | type; | ||||
| } | } | ||||
| public static int name(this TF_DataType type) | |||||
| { | |||||
| return (int)type; | |||||
| } | |||||
| public static DataType as_base_dtype(this DataType type) | public static DataType as_base_dtype(this DataType type) | ||||
| { | { | ||||
| return (int)type > 100 ? | return (int)type > 100 ? | ||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public enum VariableAggregation | |||||
| { | |||||
| NONE = 0, | |||||
| SUM = 1, | |||||
| MEAN = 2, | |||||
| ONLY_FIRST_REPLICA = 3 // ONLY_FIRST_TOWER | |||||
| } | |||||
| } | |||||
| @@ -6,6 +6,35 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class VariableScope | public class VariableScope | ||||
| { | { | ||||
| public bool? use_resource { get; set; } | |||||
| public bool use_resource { get; set; } | |||||
| private _ReuseMode _reuse { get; set; } | |||||
| private object _regularizer; | |||||
| private TF_DataType _dtype; | |||||
| public string name { get; set; } | |||||
| public VariableScope() | |||||
| { | |||||
| _reuse = _ReuseMode.AUTO_REUSE; | |||||
| } | |||||
| public RefVariable get_variable(_VariableStore var_store, | |||||
| string name, | |||||
| TensorShape shape = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
| VariableAggregation aggregation= VariableAggregation.NONE) | |||||
| { | |||||
| string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(""), scope => | |||||
| { | |||||
| if (dtype == TF_DataType.DtInvalid) | |||||
| dtype = _dtype; | |||||
| return var_store.get_variable(full_name); | |||||
| }); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Mode for variable access within a variable scope. | |||||
| /// </summary> | |||||
| public enum _ReuseMode | |||||
| { | |||||
| // Indicates that variables are to be fetched if they already exist or | |||||
| // otherwise created. | |||||
| AUTO_REUSE = 1 | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,80 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Variable store that carries a number of named Variables. | |||||
| /// </summary> | |||||
| public class _VariableStore | |||||
| { | |||||
| private Dictionary<string, object> _vars; | |||||
| private Dictionary<string, object> _partitioned_vars; | |||||
| private bool _store_eager_variables; | |||||
| public _VariableStore() | |||||
| { | |||||
| _vars = new Dictionary<string, object>(); | |||||
| _partitioned_vars = new Dictionary<string, object>(); | |||||
| _store_eager_variables = false; | |||||
| } | |||||
| public RefVariable get_variable(string name, | |||||
| TensorShape shape = null, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| IInitializer initializer = null, | |||||
| bool trainable = false, | |||||
| bool validate_shape = true, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||||
| { | |||||
| dtype = dtype.as_base_dtype(); | |||||
| trainable = variable_scope._get_trainable_value(synchronization, trainable); | |||||
| return _true_getter(name, | |||||
| shape: shape, | |||||
| dtype: dtype, | |||||
| initializer: initializer, | |||||
| trainable: trainable, | |||||
| validate_shape: validate_shape, | |||||
| synchronization: synchronization, | |||||
| aggregation: aggregation); | |||||
| } | |||||
| private RefVariable _true_getter(string name, | |||||
| TensorShape shape = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| IInitializer initializer = null, | |||||
| bool trainable = false, | |||||
| bool validate_shape = true, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||||
| { | |||||
| return _get_single_variable(name: name); | |||||
| } | |||||
| private RefVariable _get_single_variable(string name, | |||||
| TensorShape shape = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| IInitializer initializer = null, | |||||
| bool reuse = false, | |||||
| bool trainable = false, | |||||
| bool validate_shape = false, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||||
| { | |||||
| if (_vars.ContainsKey(name)) | |||||
| { | |||||
| if (!reuse) | |||||
| { | |||||
| var var = _vars[name]; | |||||
| } | |||||
| throw new NotImplementedException("_get_single_variable"); | |||||
| } | |||||
| throw new NotImplementedException("_get_single_variable"); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -11,5 +11,15 @@ namespace Tensorflow | |||||
| var g = variables.global_variables(); | var g = variables.global_variables(); | ||||
| return variables.variables_initializer(g.ToArray()); | return variables.variables_initializer(g.ToArray()); | ||||
| } | } | ||||
| public static RefVariable get_variable(string name, | |||||
| TensorShape shape = null, | |||||
| IInitializer initializer = null, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||||
| { | |||||
| var store = variable_scope._get_default_variable_store(); | |||||
| return variable_scope.get_variable_scope().get_variable(store, name, shape: shape); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -6,6 +6,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class variable_scope | public class variable_scope | ||||
| { | { | ||||
| public static string _VARSTORE_KEY = "__variable_store"; | |||||
| public static string _VARSCOPESTORE_KEY = "__varscope"; | public static string _VARSCOPESTORE_KEY = "__varscope"; | ||||
| public static bool _DEFAULT_USE_RESOURCE = false; | public static bool _DEFAULT_USE_RESOURCE = false; | ||||
| @@ -32,6 +33,17 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public static _VariableStore _get_default_variable_store() | |||||
| { | |||||
| var store = ops.get_collection(_VARSTORE_KEY); | |||||
| if (store != null) | |||||
| return (store as List<_VariableStore>)[0]; | |||||
| var store1 = new _VariableStore(); | |||||
| ops.add_to_collection(_VARSTORE_KEY, store1); | |||||
| return store1; | |||||
| } | |||||
| public static VariableScope get_variable_scope() | public static VariableScope get_variable_scope() | ||||
| { | { | ||||
| return get_variable_scope_store().current_scope; | return get_variable_scope_store().current_scope; | ||||
| @@ -65,24 +77,18 @@ namespace Tensorflow | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) | |||||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true) | |||||
| { | { | ||||
| if(synchronization == VariableSynchronization.ON_READ) | |||||
| if (synchronization == VariableSynchronization.ON_READ) | |||||
| { | { | ||||
| if (trainable.Value) | |||||
| if (trainable) | |||||
| throw new ValueError("Synchronization value can be set to " + | throw new ValueError("Synchronization value can be set to " + | ||||
| "VariableSynchronization.ON_READ only for non-trainable variables. " + | "VariableSynchronization.ON_READ only for non-trainable variables. " + | ||||
| "You have specified trainable=True and " + | "You have specified trainable=True and " + | ||||
| "synchronization=VariableSynchronization.ON_READ."); | "synchronization=VariableSynchronization.ON_READ."); | ||||
| else | |||||
| trainable = false; | |||||
| } | } | ||||
| else if (!trainable.HasValue) | |||||
| { | |||||
| trainable = true; | |||||
| } | |||||
| return trainable.Value; | |||||
| return trainable; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,21 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class TrainSaverTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void Save() | |||||
| { | |||||
| var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||||
| var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,26 @@ | |||||
| | |||||
| import tensorflow as tf | |||||
| # Create some variables. | |||||
| v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) | |||||
| v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) | |||||
| inc_v1 = v1.assign(v1+1) | |||||
| dec_v2 = v2.assign(v2-1) | |||||
| # Add an op to initialize the variables. | |||||
| init_op = tf.global_variables_initializer() | |||||
| # Add ops to save and restore all the variables. | |||||
| saver = tf.train.Saver() | |||||
| # Later, launch the model, initialize the variables, do some work, and save the | |||||
| # variables to disk. | |||||
| with tf.Session() as sess: | |||||
| sess.run(init_op) | |||||
| # Do some work with the model. | |||||
| inc_v1.op.run() | |||||
| dec_v2.op.run() | |||||
| # Save the variables to disk. | |||||
| save_path = saver.save(sess, "/tmp/model.ckpt") | |||||
| print("Model saved in path: %s" % save_path) | |||||