| @@ -250,7 +250,7 @@ namespace Tensorflow.Keras.Engine | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| IInitializer initializer = null, | IInitializer initializer = null, | ||||
| bool? trainable = null, | bool? trainable = null, | ||||
| Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null) | |||||
| Func<VariableArgs, IVariableV1> getter = null) | |||||
| { | { | ||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = TF_DataType.TF_FLOAT; | dtype = TF_DataType.TF_FLOAT; | ||||
| @@ -259,7 +259,7 @@ namespace Tensorflow.Keras.Engine | |||||
| trainable = true; | trainable = true; | ||||
| // Initialize variable when no initializer provided | // Initialize variable when no initializer provided | ||||
| if(initializer == null) | |||||
| if (initializer == null) | |||||
| { | { | ||||
| // If dtype is DT_FLOAT, provide a uniform unit scaling initializer | // If dtype is DT_FLOAT, provide a uniform unit scaling initializer | ||||
| if (dtype.is_floating()) | if (dtype.is_floating()) | ||||
| @@ -269,13 +269,18 @@ namespace Tensorflow.Keras.Engine | |||||
| else | else | ||||
| throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}"); | throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}"); | ||||
| } | } | ||||
| var variable = _add_variable_with_custom_getter(name, | |||||
| shape, | |||||
| dtype: dtype, | |||||
| getter: (getter == null) ? base_layer_utils.make_variable : getter, | |||||
| overwrite: true, | |||||
| initializer: initializer, | |||||
| trainable: trainable.Value); | |||||
| var variable = _add_variable_with_custom_getter(new VariableArgs | |||||
| { | |||||
| Name = name, | |||||
| Shape = shape, | |||||
| DType = dtype, | |||||
| Getter = getter ?? base_layer_utils.make_variable, | |||||
| Overwrite = true, | |||||
| Initializer = initializer, | |||||
| Trainable = trainable.Value | |||||
| }); | |||||
| //backend.track_variable(variable); | //backend.track_variable(variable); | ||||
| if (trainable == true) | if (trainable == true) | ||||
| _trainable_weights.Add(variable); | _trainable_weights.Add(variable); | ||||
| @@ -199,8 +199,8 @@ namespace Tensorflow.Keras.Optimizers | |||||
| } | } | ||||
| } | } | ||||
| ResourceVariable add_weight(string name, | |||||
| TensorShape shape, | |||||
| ResourceVariable add_weight(string name, | |||||
| TensorShape shape, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| IInitializer initializer = null, | IInitializer initializer = null, | ||||
| bool trainable = false, | bool trainable = false, | ||||
| @@ -213,16 +213,19 @@ namespace Tensorflow.Keras.Optimizers | |||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = TF_DataType.TF_FLOAT; | dtype = TF_DataType.TF_FLOAT; | ||||
| var variable = _add_variable_with_custom_getter(name: name, | |||||
| shape: shape, | |||||
| getter: base_layer_utils.make_variable, | |||||
| dtype: dtype, | |||||
| overwrite: true, | |||||
| initializer: initializer, | |||||
| trainable: trainable, | |||||
| use_resource: true, | |||||
| synchronization: synchronization, | |||||
| aggregation: aggregation); | |||||
| var variable = _add_variable_with_custom_getter(new VariableArgs | |||||
| { | |||||
| Name = name, | |||||
| Shape = shape, | |||||
| Getter = base_layer_utils.make_variable, | |||||
| DType = dtype, | |||||
| Overwrite = true, | |||||
| Initializer = initializer, | |||||
| Trainable = trainable, | |||||
| UseResource = true, | |||||
| Synchronization = synchronization, | |||||
| Aggregation = aggregation | |||||
| }); | |||||
| return variable as ResourceVariable; | return variable as ResourceVariable; | ||||
| } | } | ||||
| @@ -26,32 +26,26 @@ namespace Tensorflow.Keras.Utils | |||||
/// <summary>
/// Creates a new variable for a layer from the supplied <see cref="VariableArgs"/>.
/// </summary>
/// <param name="args">
/// Creation options. <c>Initializer</c> is called with <c>Shape</c> and <c>DType</c> to
/// produce the initial value; <c>DType</c>, <c>Shape</c>, <c>Name</c>, <c>Trainable</c>,
/// <c>ValidateShape</c> and <c>UseResource</c> are forwarded to <c>tf.Variable</c>.
/// </param>
/// <returns>The variable returned by <c>tf.Variable</c>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
public static IVariableV1 make_variable(VariableArgs args)
{
    // Fail with a descriptive exception instead of a NullReferenceException further down.
    if (args == null)
        throw new ArgumentNullException(nameof(args));

    ops.init_scope();

    // The initial value is wrapped in a delegate, so args.Initializer is only
    // dereferenced when that delegate is invoked (a null initializer fails there).
    Func<Tensor> init_val = () => args.Initializer.call(args.Shape, dtype: args.DType);

    return tf.Variable(init_val,
        dtype: args.DType,
        shape: args.Shape,
        name: args.Name,
        trainable: args.Trainable,
        validate_shape: args.ValidateShape,
        use_resource: args.UseResource);
}
| @@ -167,12 +167,12 @@ namespace Tensorflow.Layers | |||||
| dtype: dtype, | dtype: dtype, | ||||
| initializer: initializer, | initializer: initializer, | ||||
| trainable: trainable, | trainable: trainable, | ||||
| getter: (name1, shape1, dtype1, initializer1, trainable1) => | |||||
| tf.compat.v1.get_variable(name1, | |||||
| shape: new TensorShape(shape1), | |||||
| dtype: dtype1, | |||||
| initializer: initializer1, | |||||
| trainable: trainable1) | |||||
| getter: (args) => | |||||
| tf.compat.v1.get_variable(args.Name, | |||||
| shape: args.Shape, | |||||
| dtype: args.DType, | |||||
| initializer: args.Initializer, | |||||
| trainable: args.Trainable) | |||||
| ); | ); | ||||
| //if (init_graph != null) | //if (init_graph != null) | ||||
| @@ -27,16 +27,7 @@ namespace Tensorflow.Train | |||||
| /// Restore-on-create for a variable to be saved with this `Checkpointable`. | /// Restore-on-create for a variable to be saved with this `Checkpointable`. | ||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| protected virtual IVariableV1 _add_variable_with_custom_getter(string name, | |||||
| int[] shape, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| IInitializer initializer = null, | |||||
| Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null, | |||||
| bool overwrite = false, | |||||
| bool trainable = false, | |||||
| bool use_resource = false, | |||||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||||
| VariableAggregation aggregation = VariableAggregation.None) | |||||
| protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args) | |||||
| { | { | ||||
| ops.init_scope(); | ops.init_scope(); | ||||
| #pragma warning disable CS0219 // Variable is assigned but its value is never used | #pragma warning disable CS0219 // Variable is assigned but its value is never used | ||||
| @@ -50,15 +41,15 @@ namespace Tensorflow.Train | |||||
| checkpoint_initializer = null; | checkpoint_initializer = null; | ||||
| IVariableV1 new_variable; | IVariableV1 new_variable; | ||||
| new_variable = getter(name, shape, dtype, initializer, trainable); | |||||
| new_variable = args.Getter(args); | |||||
| // If we set an initializer and the variable processed it, tracking will not | // If we set an initializer and the variable processed it, tracking will not | ||||
| // assign again. It will add this variable to our dependencies, and if there | // assign again. It will add this variable to our dependencies, and if there | ||||
| // is a non-trivial restoration queued, it will handle that. This also | // is a non-trivial restoration queued, it will handle that. This also | ||||
| // handles slot variables. | // handles slot variables. | ||||
| if (!overwrite || new_variable is RefVariable) | |||||
| return _track_checkpointable(new_variable, name: name, | |||||
| overwrite: overwrite); | |||||
| if (!args.Overwrite || new_variable is RefVariable) | |||||
| return _track_checkpointable(new_variable, name: args.Name, | |||||
| overwrite: args.Overwrite); | |||||
| else | else | ||||
| return new_variable; | return new_variable; | ||||
| } | } | ||||
| @@ -0,0 +1,26 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
/// <summary>
/// Bundles the options used when creating a variable, so they can be passed
/// around (and handed to a custom getter) as a single object.
/// </summary>
public class VariableArgs
{
    // ---- Identity and value -------------------------------------------------

    /// <summary>Name given to the variable.</summary>
    public string Name { get; set; }

    /// <summary>Shape of the variable.</summary>
    public TensorShape Shape { get; set; }

    /// <summary>Element type of the variable. Defaults to <c>TF_DataType.DtInvalid</c>.</summary>
    public TF_DataType DType { get; set; } = TF_DataType.DtInvalid;

    /// <summary>Initializer used to produce the variable's initial value.</summary>
    public IInitializer Initializer { get; set; }

    /// <summary>
    /// Explicit initial value. NOTE(review): presumably an alternative to
    /// <see cref="Initializer"/>; confirm against the consumers of this class.
    /// </summary>
    public object InitialValue { get; set; }

    // ---- Creation behaviour -------------------------------------------------

    /// <summary>
    /// Factory delegate that receives this argument object and returns the variable.
    /// </summary>
    public Func<VariableArgs, IVariableV1> Getter { get; set; }

    /// <summary>Trainable flag. Defaults to <c>false</c> when not set.</summary>
    public bool Trainable { get; set; }

    /// <summary>Shape-validation flag. Defaults to <c>true</c>.</summary>
    public bool ValidateShape { get; set; } = true;

    /// <summary>Resource-variable flag. Defaults to <c>true</c>.</summary>
    public bool UseResource { get; set; } = true;

    /// <summary>Overwrite flag. Defaults to <c>false</c> when not set.</summary>
    public bool Overwrite { get; set; }

    /// <summary>Synchronization mode. Defaults to <c>VariableSynchronization.Auto</c>.</summary>
    public VariableSynchronization Synchronization { get; set; } = VariableSynchronization.Auto;

    /// <summary>Aggregation mode. Defaults to <c>VariableAggregation.None</c>.</summary>
    public VariableAggregation Aggregation { get; set; } = VariableAggregation.None;

    // ---- Graph placement / import ------------------------------------------

    /// <summary>Collection names associated with the variable. Null when not set.</summary>
    public List<string> Collections { get; set; }

    /// <summary>Caching device string. Defaults to the empty string.</summary>
    public string CachingDevice { get; set; } = "";

    /// <summary>Variable definition object. Null when not set.</summary>
    public VariableDef VariableDef { get; set; }

    /// <summary>Import scope string. Defaults to the empty string.</summary>
    public string ImportScope { get; set; } = "";
}
| } | |||||
| @@ -62,6 +62,7 @@ namespace Tensorflow | |||||
| public ResourceVariable Variable<T>(T data, | public ResourceVariable Variable<T>(T data, | ||||
| bool trainable = true, | bool trainable = true, | ||||
| bool validate_shape = true, | bool validate_shape = true, | ||||
| bool use_resource = true, | |||||
| string name = null, | string name = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| int[] shape = null) | int[] shape = null) | ||||