| @@ -57,5 +57,21 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| } | } | ||||
| protected virtual void add_weight(string name, | |||||
| int[] shape, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| IInitializer initializer = null, | |||||
| bool? trainable = null, | |||||
| Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null) | |||||
| { | |||||
| _add_variable_with_custom_getter(name, | |||||
| shape, | |||||
| dtype: dtype, | |||||
| getter: getter, | |||||
| overwrite: true, | |||||
| initializer: initializer, | |||||
| trainable: trainable.Value); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -53,7 +53,11 @@ namespace Tensorflow.Keras.Layers | |||||
| int channel_axis = data_format == "channels_first" ? 1 : -1; | int channel_axis = data_format == "channels_first" ? 1 : -1; | ||||
| int input_dim = input_shape.Dimensions[input_shape.NDim - 1]; | int input_dim = input_shape.Dimensions[input_shape.NDim - 1]; | ||||
| var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; | var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; | ||||
| add_weight(); | |||||
| add_weight(name: "kernel", | |||||
| shape: kernel_shape, | |||||
| initializer: kernel_initializer, | |||||
| trainable: true, | |||||
| dtype: _dtype); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -68,7 +68,11 @@ namespace Tensorflow.Layers | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| protected virtual void add_weight() | |||||
| protected virtual void add_weight(string name, | |||||
| int[] shape, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| IInitializer initializer = null, | |||||
| bool? trainable = null) | |||||
| { | { | ||||
| var default_graph = ops.get_default_graph(); | var default_graph = ops.get_default_graph(); | ||||
| Graph init_graph = null; | Graph init_graph = null; | ||||
| @@ -84,7 +88,9 @@ namespace Tensorflow.Layers | |||||
| existing_variables = variables.global_variables().ToArray(); | existing_variables = variables.global_variables().ToArray(); | ||||
| } | } | ||||
| var dtype = TF_DataType.TF_FLOAT; | |||||
| if(dtype == TF_DataType.DtInvalid) | |||||
| dtype = TF_DataType.TF_FLOAT; | |||||
| _set_scope(); | _set_scope(); | ||||
| var reuse = built || (_reuse != null && _reuse.Value); | var reuse = built || (_reuse != null && _reuse.Value); | ||||
| Python.with(tf.variable_scope(_scope, | Python.with(tf.variable_scope(_scope, | ||||
| @@ -94,8 +100,19 @@ namespace Tensorflow.Layers | |||||
| _current_scope = scope; | _current_scope = scope; | ||||
| Python.with(ops.name_scope(_name_scope()), delegate | Python.with(ops.name_scope(_name_scope()), delegate | ||||
| { | { | ||||
| base.add_weight(name, | |||||
| shape, | |||||
| dtype: dtype, | |||||
| initializer: initializer, | |||||
| trainable: trainable, | |||||
| getter: (name1, shape1, dtype1, initializer1, trainable1) => | |||||
| { | |||||
| return tf.get_variable(name1, | |||||
| shape: new TensorShape(shape1), | |||||
| dtype: dtype1, | |||||
| initializer: initializer1, | |||||
| trainable: trainable1); | |||||
| }); | |||||
| }); | }); | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| @@ -64,7 +65,16 @@ namespace Tensorflow.Operations.Initializers | |||||
| if (shape.Length == 2) | if (shape.Length == 2) | ||||
| return (shape[0], shape[1]); | return (shape[0], shape[1]); | ||||
| else | else | ||||
| throw new NotImplementedException("VarianceScaling._compute_fans"); | |||||
| { | |||||
| // Assuming convolution kernels (2D, 3D, or more). | |||||
| // kernel shape: (..., input_depth, depth) | |||||
| int receptive_field_size = 1; | |||||
| foreach (var dim in shape.Take(2)) | |||||
| receptive_field_size *= dim; | |||||
| var fan_in = shape[shape.Length - 2] * receptive_field_size; | |||||
| var fan_out = shape[shape.Length - 1] * receptive_field_size; | |||||
| return (fan_in, fan_out); | |||||
| } | |||||
| } | } | ||||
| public virtual object get_config() | public virtual object get_config() | ||||
| @@ -4,7 +4,22 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class CheckpointableBase | |||||
| public abstract class CheckpointableBase | |||||
| { | { | ||||
| /// <summary> | |||||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| protected virtual RefVariable _add_variable_with_custom_getter(string name, | |||||
| int[] shape, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| IInitializer initializer = null, | |||||
| Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null, | |||||
| bool overwrite = false, | |||||
| bool trainable = false) | |||||
| { | |||||
| var new_variable = getter(name, shape, dtype, initializer, trainable); | |||||
| throw new NotImplementedException("_add_variable_with_custom_getter"); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -48,6 +48,7 @@ namespace Tensorflow | |||||
| shape: shape, | shape: shape, | ||||
| dtype: dtype, | dtype: dtype, | ||||
| initializer: initializer, | initializer: initializer, | ||||
| reuse: resue, | |||||
| trainable: trainable, | trainable: trainable, | ||||
| synchronization: synchronization, | synchronization: synchronization, | ||||
| aggregation: aggregation); | aggregation: aggregation); | ||||
| @@ -24,6 +24,7 @@ namespace Tensorflow | |||||
| TensorShape shape = null, | TensorShape shape = null, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| object initializer = null, // IInitializer or Tensor | object initializer = null, // IInitializer or Tensor | ||||
| bool? reuse = null, | |||||
| bool? trainable = null, | bool? trainable = null, | ||||
| bool validate_shape = true, | bool validate_shape = true, | ||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
| @@ -100,7 +101,7 @@ namespace Tensorflow | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
| VariableAggregation aggregation = VariableAggregation.NONE) | VariableAggregation aggregation = VariableAggregation.NONE) | ||||
| { | { | ||||
| bool initializing_from_value = true; | |||||
| bool initializing_from_value = false; | |||||
| if (use_resource == null) | if (use_resource == null) | ||||
| use_resource = false; | use_resource = false; | ||||