| @@ -46,7 +46,10 @@ namespace Tensorflow.Keras.Utils | |||||
| Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); | Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); | ||||
| var variable_dtype = dtype.as_base_dtype(); | var variable_dtype = dtype.as_base_dtype(); | ||||
| var v = tf.VariableV1(init_val); | |||||
| var v = tf.VariableV1(init_val, | |||||
| use_resource: use_resource, | |||||
| dtype: dtype, | |||||
| shape: shape); | |||||
| return v; | return v; | ||||
| } | } | ||||
| @@ -143,8 +143,8 @@ namespace Tensorflow.Train | |||||
| { | { | ||||
| ops.init_scope(); | ops.init_scope(); | ||||
| var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
| return (_get_non_slot_variable("beta1_power", graph: graph), | |||||
| _get_non_slot_variable("beta2_power", graph: graph)); | |||||
| return (_get_non_slot_variable("beta1_power", graph: graph) as RefVariable, | |||||
| _get_non_slot_variable("beta2_power", graph: graph) as RefVariable); | |||||
| } | } | ||||
| public override void _prepare() | public override void _prepare() | ||||
| @@ -44,7 +44,7 @@ namespace Tensorflow | |||||
| public Tensor LearningRateTensor => _lr_t; | public Tensor LearningRateTensor => _lr_t; | ||||
| public bool _use_locking; | public bool _use_locking; | ||||
| public Dictionary<string, Dictionary<string, RefVariable>> _slots; | public Dictionary<string, Dictionary<string, RefVariable>> _slots; | ||||
| public Dictionary<string, RefVariable> _non_slot_dict; | |||||
| public Dictionary<string, VariableV1> _non_slot_dict; | |||||
| public Dictionary<string, object> _deferred_slot_restorations; | public Dictionary<string, object> _deferred_slot_restorations; | ||||
| SlotCreator slot_creator = new SlotCreator(); | SlotCreator slot_creator = new SlotCreator(); | ||||
| @@ -58,7 +58,7 @@ namespace Tensorflow | |||||
| _lr = learning_rate; | _lr = learning_rate; | ||||
| // Dictionary of slots. | // Dictionary of slots. | ||||
| _slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | _slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | ||||
| _non_slot_dict = new Dictionary<string, RefVariable>(); | |||||
| _non_slot_dict = new Dictionary<string, VariableV1>(); | |||||
| _deferred_slot_restorations = new Dictionary<string, object>(); | _deferred_slot_restorations = new Dictionary<string, object>(); | ||||
| } | } | ||||
| @@ -72,7 +72,7 @@ namespace Tensorflow | |||||
| _lr_t = learning_rate; | _lr_t = learning_rate; | ||||
| // Dictionary of slots. | // Dictionary of slots. | ||||
| _slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | _slots = new Dictionary<string, Dictionary<string, RefVariable>>(); | ||||
| _non_slot_dict = new Dictionary<string, RefVariable>(); | |||||
| _non_slot_dict = new Dictionary<string, VariableV1>(); | |||||
| _deferred_slot_restorations = new Dictionary<string, object>(); | _deferred_slot_restorations = new Dictionary<string, object>(); | ||||
| } | } | ||||
| @@ -239,7 +239,7 @@ namespace Tensorflow | |||||
| /// <param name="initial_value"></param> | /// <param name="initial_value"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <param name="colocate_with"></param> | /// <param name="colocate_with"></param> | ||||
| protected RefVariable _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) | |||||
| protected VariableV1 _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) | |||||
| { | { | ||||
| // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. | // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. | ||||
| var graph = colocate_with.graph; | var graph = colocate_with.graph; | ||||
| @@ -333,7 +333,7 @@ namespace Tensorflow | |||||
| return $"{var.op.graph.graph_key}.{var.op.name}"; | return $"{var.op.graph.graph_key}.{var.op.name}"; | ||||
| } | } | ||||
| protected RefVariable _get_non_slot_variable(string name, Graph graph = null) | |||||
| protected VariableV1 _get_non_slot_variable(string name, Graph graph = null) | |||||
| { | { | ||||
| var key = $"{name}.{graph.graph_key}"; | var key = $"{name}.{graph.graph_key}"; | ||||
| var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; | var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; | ||||
| @@ -53,7 +53,7 @@ namespace Tensorflow.Train | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <param name="trackable"></param> | /// <param name="trackable"></param> | ||||
| protected void _handle_deferred_dependencies(string name, RefVariable trackable) | |||||
| protected void _handle_deferred_dependencies(string name, VariableV1 trackable) | |||||
| { | { | ||||
| _maybe_initialize_trackable(); | _maybe_initialize_trackable(); | ||||
| // TODO | // TODO | ||||
| @@ -16,6 +16,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -53,7 +54,8 @@ namespace Tensorflow | |||||
| string name = null, | string name = null, | ||||
| VariableDef variable_def = null, | VariableDef variable_def = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| string import_scope = "") : base(initial_value, | |||||
| string import_scope = "", | |||||
| TensorShape shape = null) : base(initial_value, | |||||
| trainable, | trainable, | ||||
| collections, | collections, | ||||
| validate_shape, | validate_shape, | ||||
| @@ -69,11 +71,31 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| throw new NotImplementedException("ResourceVariable _init_from_args"); | |||||
| //_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); | |||||
| _init_from_args(initial_value: initial_value, | |||||
| trainable: trainable, | |||||
| collections: collections, | |||||
| caching_device: caching_device, | |||||
| name: name, | |||||
| dtype: dtype, | |||||
| shape: shape); | |||||
| } | } | ||||
| } | } | ||||
| private void _init_from_args(object initial_value = null, | |||||
| bool trainable = true, | |||||
| List<string> collections = null, | |||||
| string caching_device = "", | |||||
| string name = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| TensorShape shape = null) | |||||
| { | |||||
| var init_from_fn = initial_value.GetType().Name == "Func`1"; | |||||
| if(collections == null) | |||||
| collections = new List<string>() { tf.GraphKeys.GLOBAL_VARIABLES }; | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| private void _init_from_proto(VariableDef variable_def, string import_scope = null) | private void _init_from_proto(VariableDef variable_def, string import_scope = null) | ||||
| { | { | ||||
| _in_graph_mode = true; | _in_graph_mode = true; | ||||
| @@ -71,7 +71,7 @@ namespace Tensorflow | |||||
| trainable: trainable, | trainable: trainable, | ||||
| collections: collections, | collections: collections, | ||||
| synchronization: synchronization, | synchronization: synchronization, | ||||
| aggregation: aggregation); | |||||
| aggregation: aggregation) as RefVariable; | |||||
| }); | }); | ||||
| } | } | ||||
| } | } | ||||
| @@ -28,7 +28,7 @@ namespace Tensorflow | |||||
| /// the variable are fixed. The value can be changed using one of the assign methods. | /// the variable are fixed. The value can be changed using one of the assign methods. | ||||
| /// https://tensorflow.org/guide/variables | /// https://tensorflow.org/guide/variables | ||||
| /// </summary> | /// </summary> | ||||
| public class VariableV1 | |||||
| public abstract class VariableV1 | |||||
| { | { | ||||
| public virtual string name { get; } | public virtual string name { get; } | ||||
| public virtual Tensor graph_element { get; } | public virtual Tensor graph_element { get; } | ||||
| @@ -36,7 +36,7 @@ namespace Tensorflow | |||||
| _store_eager_variables = false; | _store_eager_variables = false; | ||||
| } | } | ||||
| public RefVariable get_variable(string name, | |||||
| public VariableV1 get_variable(string name, | |||||
| TensorShape shape = null, | TensorShape shape = null, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| object initializer = null, // IInitializer or Tensor | object initializer = null, // IInitializer or Tensor | ||||
| @@ -61,7 +61,7 @@ namespace Tensorflow | |||||
| aggregation: aggregation); | aggregation: aggregation); | ||||
| } | } | ||||
| private RefVariable _true_getter(string name, | |||||
| private VariableV1 _true_getter(string name, | |||||
| TensorShape shape = null, | TensorShape shape = null, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| object initializer = null, | object initializer = null, | ||||
| @@ -110,7 +110,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private RefVariable _get_single_variable(string name, | |||||
| private VariableV1 _get_single_variable(string name, | |||||
| TensorShape shape = null, | TensorShape shape = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| IInitializer initializer = null, | IInitializer initializer = null, | ||||
| @@ -136,7 +136,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("_get_single_variable"); | throw new NotImplementedException("_get_single_variable"); | ||||
| } | } | ||||
| RefVariable v = null; | |||||
| VariableV1 v = null; | |||||
| // Create the tensor to initialize the variable with default value. | // Create the tensor to initialize the variable with default value. | ||||
| if (initializer == null) | if (initializer == null) | ||||
| { | { | ||||
| @@ -172,11 +172,12 @@ namespace Tensorflow | |||||
| return $"{prefix}_{idx}"; | return $"{prefix}_{idx}"; | ||||
| } | } | ||||
| public static RefVariable default_variable_creator(object initial_value, | |||||
| public static VariableV1 default_variable_creator(object initial_value, | |||||
| string name = null, | string name = null, | ||||
| bool? trainable = null, | bool? trainable = null, | ||||
| List<string> collections = null, | List<string> collections = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| int[] shape = null, | |||||
| bool validate_shape = false, | bool validate_shape = false, | ||||
| bool ? use_resource = null, | bool ? use_resource = null, | ||||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | VariableSynchronization synchronization = VariableSynchronization.Auto, | ||||
| @@ -193,7 +194,13 @@ namespace Tensorflow | |||||
| if (use_resource.Value) | if (use_resource.Value) | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| return new ResourceVariable(initial_value, | |||||
| trainable: trainable.Value, | |||||
| validate_shape: validate_shape, | |||||
| collections: collections, | |||||
| name: name, | |||||
| dtype: dtype, | |||||
| shape: shape); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -55,7 +55,7 @@ namespace Tensorflow | |||||
| trainable: trainable, | trainable: trainable, | ||||
| validate_shape: validate_shape, | validate_shape: validate_shape, | ||||
| name: name, | name: name, | ||||
| dtype: dtype); | |||||
| dtype: dtype) as RefVariable; | |||||
| } | } | ||||
| public VariableV1 VariableV1<T>(T data, | public VariableV1 VariableV1<T>(T data, | ||||
| @@ -63,14 +63,16 @@ namespace Tensorflow | |||||
| bool validate_shape = true, | bool validate_shape = true, | ||||
| string name = null, | string name = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| bool use_resource = false) | |||||
| bool use_resource = false, | |||||
| int[] shape = null) | |||||
| { | { | ||||
| return Tensorflow.variable_scope.default_variable_creator(data, | return Tensorflow.variable_scope.default_variable_creator(data, | ||||
| trainable: trainable, | trainable: trainable, | ||||
| validate_shape: validate_shape, | validate_shape: validate_shape, | ||||
| name: name, | name: name, | ||||
| dtype: dtype, | dtype: dtype, | ||||
| use_resource: use_resource); | |||||
| use_resource: use_resource, | |||||
| shape: shape); | |||||
| } | } | ||||
| public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | ||||