| @@ -250,7 +250,7 @@ namespace Tensorflow.Keras.Engine | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| IInitializer initializer = null, | |||
| bool? trainable = null, | |||
| Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null) | |||
| Func<VariableArgs, IVariableV1> getter = null) | |||
| { | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = TF_DataType.TF_FLOAT; | |||
| @@ -259,7 +259,7 @@ namespace Tensorflow.Keras.Engine | |||
| trainable = true; | |||
| // Initialize variable when no initializer provided | |||
| if(initializer == null) | |||
| if (initializer == null) | |||
| { | |||
| // If dtype is DT_FLOAT, provide a uniform unit scaling initializer | |||
| if (dtype.is_floating()) | |||
| @@ -269,13 +269,18 @@ namespace Tensorflow.Keras.Engine | |||
| else | |||
| throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}"); | |||
| } | |||
| var variable = _add_variable_with_custom_getter(name, | |||
| shape, | |||
| dtype: dtype, | |||
| getter: (getter == null) ? base_layer_utils.make_variable : getter, | |||
| overwrite: true, | |||
| initializer: initializer, | |||
| trainable: trainable.Value); | |||
| var variable = _add_variable_with_custom_getter(new VariableArgs | |||
| { | |||
| Name = name, | |||
| Shape = shape, | |||
| DType = dtype, | |||
| Getter = getter ?? base_layer_utils.make_variable, | |||
| Overwrite = true, | |||
| Initializer = initializer, | |||
| Trainable = trainable.Value | |||
| }); | |||
| //backend.track_variable(variable); | |||
| if (trainable == true) | |||
| _trainable_weights.Add(variable); | |||
| @@ -199,8 +199,8 @@ namespace Tensorflow.Keras.Optimizers | |||
| } | |||
| } | |||
| ResourceVariable add_weight(string name, | |||
| TensorShape shape, | |||
| ResourceVariable add_weight(string name, | |||
| TensorShape shape, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| IInitializer initializer = null, | |||
| bool trainable = false, | |||
| @@ -213,16 +213,19 @@ namespace Tensorflow.Keras.Optimizers | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = TF_DataType.TF_FLOAT; | |||
| var variable = _add_variable_with_custom_getter(name: name, | |||
| shape: shape, | |||
| getter: base_layer_utils.make_variable, | |||
| dtype: dtype, | |||
| overwrite: true, | |||
| initializer: initializer, | |||
| trainable: trainable, | |||
| use_resource: true, | |||
| synchronization: synchronization, | |||
| aggregation: aggregation); | |||
| var variable = _add_variable_with_custom_getter(new VariableArgs | |||
| { | |||
| Name = name, | |||
| Shape = shape, | |||
| Getter = base_layer_utils.make_variable, | |||
| DType = dtype, | |||
| Overwrite = true, | |||
| Initializer = initializer, | |||
| Trainable = trainable, | |||
| UseResource = true, | |||
| Synchronization = synchronization, | |||
| Aggregation = aggregation | |||
| }); | |||
| return variable as ResourceVariable; | |||
| } | |||
| @@ -26,32 +26,26 @@ namespace Tensorflow.Keras.Utils | |||
| /// <summary> | |||
| /// Adds a new variable to the layer. | |||
| /// </summary> | |||
| /// <param name="name"></param> | |||
| /// <param name="shape"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <param name="initializer"></param> | |||
| /// <param name="trainable"></param> | |||
| /// <param name="args"></param> | |||
| /// <returns></returns> | |||
| public static IVariableV1 make_variable(string name, | |||
| int[] shape, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| IInitializer initializer = null, | |||
| bool trainable = true) | |||
| public static IVariableV1 make_variable(VariableArgs args) | |||
| { | |||
| #pragma warning disable CS0219 // Variable is assigned but its value is never used | |||
| var initializing_from_value = false; | |||
| bool use_resource = true; | |||
| #pragma warning restore CS0219 // Variable is assigned but its value is never used | |||
| ops.init_scope(); | |||
| Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); | |||
| Func<Tensor> init_val = () => args.Initializer.call(args.Shape, dtype: args.DType); | |||
| var variable_dtype = dtype.as_base_dtype(); | |||
| var variable_dtype = args.DType.as_base_dtype(); | |||
| var v = tf.Variable(init_val, | |||
| dtype: dtype, | |||
| shape: shape, | |||
| name: name); | |||
| dtype: args.DType, | |||
| shape: args.Shape, | |||
| name: args.Name, | |||
| trainable: args.Trainable, | |||
| validate_shape: args.ValidateShape, | |||
| use_resource: args.UseResource); | |||
| return v; | |||
| } | |||
| @@ -167,12 +167,12 @@ namespace Tensorflow.Layers | |||
| dtype: dtype, | |||
| initializer: initializer, | |||
| trainable: trainable, | |||
| getter: (name1, shape1, dtype1, initializer1, trainable1) => | |||
| tf.compat.v1.get_variable(name1, | |||
| shape: new TensorShape(shape1), | |||
| dtype: dtype1, | |||
| initializer: initializer1, | |||
| trainable: trainable1) | |||
| getter: (args) => | |||
| tf.compat.v1.get_variable(args.Name, | |||
| shape: args.Shape, | |||
| dtype: args.DType, | |||
| initializer: args.Initializer, | |||
| trainable: args.Trainable) | |||
| ); | |||
| //if (init_graph != null) | |||
| @@ -27,16 +27,7 @@ namespace Tensorflow.Train | |||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| protected virtual IVariableV1 _add_variable_with_custom_getter(string name, | |||
| int[] shape, | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| IInitializer initializer = null, | |||
| Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null, | |||
| bool overwrite = false, | |||
| bool trainable = false, | |||
| bool use_resource = false, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| VariableAggregation aggregation = VariableAggregation.None) | |||
| protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args) | |||
| { | |||
| ops.init_scope(); | |||
| #pragma warning disable CS0219 // Variable is assigned but its value is never used | |||
| @@ -50,15 +41,15 @@ namespace Tensorflow.Train | |||
| checkpoint_initializer = null; | |||
| IVariableV1 new_variable; | |||
| new_variable = getter(name, shape, dtype, initializer, trainable); | |||
| new_variable = args.Getter(args); | |||
| // If we set an initializer and the variable processed it, tracking will not | |||
| // assign again. It will add this variable to our dependencies, and if there | |||
| // is a non-trivial restoration queued, it will handle that. This also | |||
| // handles slot variables. | |||
| if (!overwrite || new_variable is RefVariable) | |||
| return _track_checkpointable(new_variable, name: name, | |||
| overwrite: overwrite); | |||
| if (!args.Overwrite || new_variable is RefVariable) | |||
| return _track_checkpointable(new_variable, name: args.Name, | |||
| overwrite: args.Overwrite); | |||
| else | |||
| return new_variable; | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class VariableArgs | |||
| { | |||
| public object InitialValue { get; set; } | |||
| public Func<VariableArgs, IVariableV1> Getter { get; set; } | |||
| public string Name { get; set; } | |||
| public TensorShape Shape { get; set; } | |||
| public TF_DataType DType { get; set; } = TF_DataType.DtInvalid; | |||
| public IInitializer Initializer { get; set; } | |||
| public bool Trainable { get; set; } | |||
| public bool ValidateShape { get; set; } = true; | |||
| public bool UseResource { get; set; } = true; | |||
| public bool Overwrite { get; set; } | |||
| public List<string> Collections { get; set; } | |||
| public string CachingDevice { get; set; } = ""; | |||
| public VariableDef VariableDef { get; set; } | |||
| public string ImportScope { get; set; } = ""; | |||
| public VariableSynchronization Synchronization { get; set; } = VariableSynchronization.Auto; | |||
| public VariableAggregation Aggregation { get; set; } = VariableAggregation.None; | |||
| } | |||
| } | |||
| @@ -62,6 +62,7 @@ namespace Tensorflow | |||
| public ResourceVariable Variable<T>(T data, | |||
| bool trainable = true, | |||
| bool validate_shape = true, | |||
| bool use_resource = true, | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| int[] shape = null) | |||