| @@ -46,7 +46,10 @@ namespace Tensorflow.Keras.Utils | |||
| Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); | |||
| var variable_dtype = dtype.as_base_dtype(); | |||
| var v = tf.VariableV1(init_val); | |||
| var v = tf.VariableV1(init_val, | |||
| use_resource: use_resource, | |||
| dtype: dtype, | |||
| shape: shape); | |||
| return v; | |||
| } | |||
| @@ -143,8 +143,8 @@ namespace Tensorflow.Train | |||
| { | |||
| ops.init_scope(); | |||
| var graph = ops.get_default_graph(); | |||
| return (_get_non_slot_variable("beta1_power", graph: graph), | |||
| _get_non_slot_variable("beta2_power", graph: graph)); | |||
| return (_get_non_slot_variable("beta1_power", graph: graph) as RefVariable, | |||
| _get_non_slot_variable("beta2_power", graph: graph) as RefVariable); | |||
| } | |||
| public override void _prepare() | |||
| @@ -44,7 +44,7 @@ namespace Tensorflow | |||
| public Tensor LearningRateTensor => _lr_t; | |||
| public bool _use_locking; | |||
| public Dictionary<string, Dictionary<string, RefVariable>> _slots; | |||
| public Dictionary<string, RefVariable> _non_slot_dict; | |||
| public Dictionary<string, VariableV1> _non_slot_dict; | |||
| public Dictionary<string, object> _deferred_slot_restorations; | |||
| SlotCreator slot_creator = new SlotCreator(); | |||
| @@ -58,7 +58,7 @@ namespace Tensorflow | |||
| _lr = learning_rate; | |||
| // Dictionary of slots. | |||
| _slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | |||
| _non_slot_dict = new Dictionary<string, RefVariable>(); | |||
| _non_slot_dict = new Dictionary<string, VariableV1>(); | |||
| _deferred_slot_restorations = new Dictionary<string, object>(); | |||
| } | |||
| @@ -72,7 +72,7 @@ namespace Tensorflow | |||
| _lr_t = learning_rate; | |||
| // Dictionary of slots. | |||
| _slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | |||
| _non_slot_dict = new Dictionary<string, RefVariable>(); | |||
| _non_slot_dict = new Dictionary<string, VariableV1>(); | |||
| _deferred_slot_restorations = new Dictionary<string, object>(); | |||
| } | |||
| @@ -239,7 +239,7 @@ namespace Tensorflow | |||
| /// <param name="initial_value"></param> | |||
| /// <param name="name"></param> | |||
| /// <param name="colocate_with"></param> | |||
| protected RefVariable _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) | |||
| protected VariableV1 _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) | |||
| { | |||
| // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. | |||
| var graph = colocate_with.graph; | |||
| @@ -333,7 +333,7 @@ namespace Tensorflow | |||
| return $"{var.op.graph.graph_key}.{var.op.name}"; | |||
| } | |||
| protected RefVariable _get_non_slot_variable(string name, Graph graph = null) | |||
| protected VariableV1 _get_non_slot_variable(string name, Graph graph = null) | |||
| { | |||
| var key = $"{name}.{graph.graph_key}"; | |||
| var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; | |||
| @@ -53,7 +53,7 @@ namespace Tensorflow.Train | |||
| /// </summary> | |||
| /// <param name="name"></param> | |||
| /// <param name="trackable"></param> | |||
| protected void _handle_deferred_dependencies(string name, RefVariable trackable) | |||
| protected void _handle_deferred_dependencies(string name, VariableV1 trackable) | |||
| { | |||
| _maybe_initialize_trackable(); | |||
| // TODO | |||
| @@ -16,6 +16,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -53,7 +54,8 @@ namespace Tensorflow | |||
| string name = null, | |||
| VariableDef variable_def = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| string import_scope = "") : base(initial_value, | |||
| string import_scope = "", | |||
| TensorShape shape = null) : base(initial_value, | |||
| trainable, | |||
| collections, | |||
| validate_shape, | |||
| @@ -69,11 +71,31 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("ResourceVariable _init_from_args"); | |||
| //_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); | |||
| _init_from_args(initial_value: initial_value, | |||
| trainable: trainable, | |||
| collections: collections, | |||
| caching_device: caching_device, | |||
| name: name, | |||
| dtype: dtype, | |||
| shape: shape); | |||
| } | |||
| } | |||
| private void _init_from_args(object initial_value = null, | |||
| bool trainable = true, | |||
| List<string> collections = null, | |||
| string caching_device = "", | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| TensorShape shape = null) | |||
| { | |||
| var init_from_fn = initial_value.GetType().Name == "Func`1"; | |||
| if(collections == null) | |||
| collections = new List<string>() { tf.GraphKeys.GLOBAL_VARIABLES }; | |||
| throw new NotImplementedException(""); | |||
| } | |||
| private void _init_from_proto(VariableDef variable_def, string import_scope = null) | |||
| { | |||
| _in_graph_mode = true; | |||
| @@ -71,7 +71,7 @@ namespace Tensorflow | |||
| trainable: trainable, | |||
| collections: collections, | |||
| synchronization: synchronization, | |||
| aggregation: aggregation); | |||
| aggregation: aggregation) as RefVariable; | |||
| }); | |||
| } | |||
| } | |||
| @@ -28,7 +28,7 @@ namespace Tensorflow | |||
| /// the variable are fixed. The value can be changed using one of the assign methods. | |||
| /// https://tensorflow.org/guide/variables | |||
| /// </summary> | |||
| public class VariableV1 | |||
| public abstract class VariableV1 | |||
| { | |||
| public virtual string name { get; } | |||
| public virtual Tensor graph_element { get; } | |||
| @@ -36,7 +36,7 @@ namespace Tensorflow | |||
| _store_eager_variables = false; | |||
| } | |||
| public RefVariable get_variable(string name, | |||
| public VariableV1 get_variable(string name, | |||
| TensorShape shape = null, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| object initializer = null, // IInitializer or Tensor | |||
| @@ -61,7 +61,7 @@ namespace Tensorflow | |||
| aggregation: aggregation); | |||
| } | |||
| private RefVariable _true_getter(string name, | |||
| private VariableV1 _true_getter(string name, | |||
| TensorShape shape = null, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| object initializer = null, | |||
| @@ -110,7 +110,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private RefVariable _get_single_variable(string name, | |||
| private VariableV1 _get_single_variable(string name, | |||
| TensorShape shape = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| IInitializer initializer = null, | |||
| @@ -136,7 +136,7 @@ namespace Tensorflow | |||
| throw new NotImplementedException("_get_single_variable"); | |||
| } | |||
| RefVariable v = null; | |||
| VariableV1 v = null; | |||
| // Create the tensor to initialize the variable with default value. | |||
| if (initializer == null) | |||
| { | |||
| @@ -172,11 +172,12 @@ namespace Tensorflow | |||
| return $"{prefix}_{idx}"; | |||
| } | |||
| public static RefVariable default_variable_creator(object initial_value, | |||
| public static VariableV1 default_variable_creator(object initial_value, | |||
| string name = null, | |||
| bool? trainable = null, | |||
| List<string> collections = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| int[] shape = null, | |||
| bool validate_shape = false, | |||
| bool ? use_resource = null, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| @@ -193,7 +194,13 @@ namespace Tensorflow | |||
| if (use_resource.Value) | |||
| { | |||
| throw new NotImplementedException(); | |||
| return new ResourceVariable(initial_value, | |||
| trainable: trainable.Value, | |||
| validate_shape: validate_shape, | |||
| collections: collections, | |||
| name: name, | |||
| dtype: dtype, | |||
| shape: shape); | |||
| } | |||
| else | |||
| { | |||
| @@ -55,7 +55,7 @@ namespace Tensorflow | |||
| trainable: trainable, | |||
| validate_shape: validate_shape, | |||
| name: name, | |||
| dtype: dtype); | |||
| dtype: dtype) as RefVariable; | |||
| } | |||
| public VariableV1 VariableV1<T>(T data, | |||
| @@ -63,14 +63,16 @@ namespace Tensorflow | |||
| bool validate_shape = true, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| bool use_resource = false) | |||
| bool use_resource = false, | |||
| int[] shape = null) | |||
| { | |||
| return Tensorflow.variable_scope.default_variable_creator(data, | |||
| trainable: trainable, | |||
| validate_shape: validate_shape, | |||
| name: name, | |||
| dtype: dtype, | |||
| use_resource: use_resource); | |||
| use_resource: use_resource, | |||
| shape: shape); | |||
| } | |||
| public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | |||