| @@ -40,7 +40,7 @@ namespace Tensorflow | |||
| string data_format= "channels_last", | |||
| int[] dilation_rate = null, | |||
| bool use_bias = true, | |||
| IActivation activation = null, | |||
| Activation activation = null, | |||
| IInitializer kernel_initializer = null, | |||
| IInitializer bias_initializer = null, | |||
| bool trainable = true, | |||
| @@ -53,20 +53,23 @@ namespace Tensorflow | |||
| if (bias_initializer == null) | |||
| bias_initializer = tf.zeros_initializer; | |||
| var layer = new Conv2D(filters, | |||
| kernel_size: kernel_size, | |||
| strides: strides, | |||
| padding: padding, | |||
| data_format: data_format, | |||
| dilation_rate: dilation_rate, | |||
| activation: activation, | |||
| use_bias: use_bias, | |||
| kernel_initializer: kernel_initializer, | |||
| bias_initializer: bias_initializer, | |||
| trainable: trainable, | |||
| name: name); | |||
| var layer = new Conv2D(new Conv2DArgs | |||
| { | |||
| Filters = filters, | |||
| KernelSize = kernel_size, | |||
| Strides = strides, | |||
| Padding = padding, | |||
| DataFormat = data_format, | |||
| DilationRate = dilation_rate, | |||
| Activation = activation, | |||
| UseBias = use_bias, | |||
| KernelInitializer = kernel_initializer, | |||
| BiasInitializer = bias_initializer, | |||
| Trainable = trainable, | |||
| Name = name | |||
| }); | |||
| return layer.apply(inputs).Item1; | |||
| return layer.Apply(inputs); | |||
| } | |||
| /// <summary> | |||
| @@ -140,13 +143,16 @@ namespace Tensorflow | |||
| string data_format = "channels_last", | |||
| string name = null) | |||
| { | |||
| var layer = new MaxPooling2D(pool_size: pool_size, | |||
| strides: strides, | |||
| padding: padding, | |||
| data_format: data_format, | |||
| name: name); | |||
| var layer = new MaxPooling2D(new MaxPooling2DArgs | |||
| { | |||
| PoolSize = pool_size, | |||
| Strides = strides, | |||
| Padding = padding, | |||
| DataFormat = data_format, | |||
| Name = name | |||
| }); | |||
| return layer.apply(inputs).Item1; | |||
| return layer.Apply(inputs); | |||
| } | |||
| /// <summary> | |||
| @@ -66,9 +66,8 @@ namespace Tensorflow | |||
| Tensor keep = null; | |||
| if (keep_prob != null) | |||
| keep = 1.0f - keep_prob; | |||
| var rate_tensor = keep; | |||
| return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name); | |||
| return nn_ops.dropout_v2(x, rate: rate.Value, noise_shape: noise_shape, seed: seed, name: name); | |||
| } | |||
| /// <summary> | |||
| @@ -41,6 +41,14 @@ namespace Tensorflow.Framework | |||
| name: name); | |||
| } | |||
| public static Tensor smart_cond(bool pred, | |||
| Func<Tensor> true_fn = null, | |||
| Func<Tensor> false_fn = null, | |||
| string name = null) | |||
| { | |||
| return pred ? true_fn() : false_fn(); | |||
| } | |||
| public static bool? smart_constant_value(Tensor pred) | |||
| { | |||
| var pred_value = tensor_util.constant_value(pred); | |||
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class Conv2DArgs : ConvArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class ConvArgs : LayerArgs | |||
| { | |||
| public int Rank { get; set; } = 2; | |||
| public int Filters { get; set; } | |||
| public TensorShape KernelSize { get; set; } = 5; | |||
| /// <summary> | |||
| /// specifying the stride length of the convolution. | |||
| /// </summary> | |||
| public TensorShape Strides { get; set; } = (1, 1); | |||
| public string Padding { get; set; } = "valid"; | |||
| public string DataFormat { get; set; } | |||
| public TensorShape DilationRate { get; set; } = (1, 1); | |||
| public int Groups { get; set; } = 1; | |||
| public Activation Activation { get; set; } | |||
| public bool UseBias { get; set; } | |||
| public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | |||
| public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | |||
| public IInitializer KernelRegularizer { get; set; } | |||
| public IInitializer BiasRegularizer { get; set; } | |||
| public Action KernelConstraint { get; set; } | |||
| public Action BiasConstraint { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,27 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class DropoutArgs : LayerArgs | |||
| { | |||
| /// <summary> | |||
| /// Float between 0 and 1. Fraction of the input units to drop. | |||
| /// </summary> | |||
| public float Rate { get; set; } | |||
| /// <summary> | |||
| /// 1D integer tensor representing the shape of the | |||
| /// binary dropout mask that will be multiplied with the input. | |||
| /// </summary> | |||
| public TensorShape NoiseShape { get; set; } | |||
| /// <summary> | |||
| /// random seed. | |||
| /// </summary> | |||
| public int? Seed { get; set; } | |||
| public bool SupportsMasking { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class FlattenArgs : LayerArgs | |||
| { | |||
| public string DataFormat { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,12 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Layers; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class MaxPooling2DArgs : Pooling2DArgs | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,34 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public class Pooling2DArgs : LayerArgs | |||
| { | |||
| /// <summary> | |||
| /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. | |||
| /// </summary> | |||
| public IPoolFunction PoolFunction { get; set; } | |||
| /// <summary> | |||
| /// specifying the size of the pooling window. | |||
| /// </summary> | |||
| public TensorShape PoolSize { get; set; } | |||
| /// <summary> | |||
| /// specifying the strides of the pooling operation. | |||
| /// </summary> | |||
| public TensorShape Strides { get; set; } | |||
| /// <summary> | |||
| /// The padding method, either 'valid' or 'same'. | |||
| /// </summary> | |||
| public string Padding { get; set; } = "valid"; | |||
| /// <summary> | |||
| /// one of `channels_last` (default) or `channels_first`. | |||
| /// </summary> | |||
| public string DataFormat { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,39 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| public class Flatten : Layer | |||
| { | |||
| FlattenArgs args; | |||
| InputSpec input_spec; | |||
| bool _channels_first; | |||
| public Flatten(FlattenArgs args) | |||
| : base(args) | |||
| { | |||
| args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); | |||
| input_spec = new InputSpec(min_ndim: 1); | |||
| _channels_first = args.DataFormat == "channels_first"; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| { | |||
| if (_channels_first) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 }); | |||
| } | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| @@ -25,5 +25,80 @@ namespace Tensorflow.Keras.Engine | |||
| _layers.Add(layer); | |||
| return layer; | |||
| } | |||
| protected Layer Conv2D(int filters, | |||
| int kernel_size, | |||
| TensorShape strides = null, | |||
| string padding = "valid", | |||
| string data_format = null, | |||
| TensorShape dilation_rate = null, | |||
| int groups = 1, | |||
| Activation activation = null, | |||
| bool use_bias = true, | |||
| IInitializer kernel_initializer = null, | |||
| IInitializer bias_initializer = null, | |||
| bool trainable = true, | |||
| string name = null) | |||
| { | |||
| var layer = new Conv2D(new Conv2DArgs | |||
| { | |||
| Filters = filters, | |||
| KernelSize = kernel_size, | |||
| Strides = strides ?? (1, 1), | |||
| Padding = padding, | |||
| DataFormat = data_format, | |||
| DilationRate = dilation_rate ?? (1, 1), | |||
| Groups = groups, | |||
| Activation = activation, | |||
| UseBias = use_bias, | |||
| KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, | |||
| BiasInitializer = bias_initializer ?? tf.zeros_initializer, | |||
| Trainable = trainable, | |||
| Name = name | |||
| }); | |||
| _layers.Add(layer); | |||
| return layer; | |||
| } | |||
| protected Layer MaxPooling2D(TensorShape pool_size, | |||
| TensorShape strides, | |||
| string padding = "valid", | |||
| string data_format = null, | |||
| string name = null) | |||
| { | |||
| var layer = new MaxPooling2D(new MaxPooling2DArgs | |||
| { | |||
| PoolSize = pool_size, | |||
| Strides = strides, | |||
| Padding = padding, | |||
| DataFormat = data_format, | |||
| Name = name | |||
| }); | |||
| _layers.Add(layer); | |||
| return layer; | |||
| } | |||
| protected Layer Dropout(float rate, TensorShape noise_shape = null, int? seed = null) | |||
| { | |||
| var layer = new Dropout(new DropoutArgs | |||
| { | |||
| Rate = rate, | |||
| NoiseShape = noise_shape, | |||
| Seed = seed | |||
| }); | |||
| _layers.Add(layer); | |||
| return layer; | |||
| } | |||
| protected Layer Flatten() | |||
| { | |||
| var layer = new Flatten(new FlattenArgs()); | |||
| _layers.Add(layer); | |||
| return layer; | |||
| } | |||
| } | |||
| } | |||
| @@ -16,6 +16,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Operations; | |||
| @@ -23,87 +24,73 @@ using Tensorflow.Operations.Activation; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public class Conv : Tensorflow.Layers.Layer | |||
| public class Conv : Layer | |||
| { | |||
| protected int rank; | |||
| protected int filters; | |||
| protected int[] kernel_size; | |||
| protected int[] strides; | |||
| protected string padding; | |||
| protected string data_format; | |||
| protected int[] dilation_rate; | |||
| protected IActivation activation; | |||
| protected bool use_bias; | |||
| protected IInitializer kernel_initializer; | |||
| protected IInitializer bias_initializer; | |||
| protected RefVariable kernel; | |||
| protected RefVariable bias; | |||
| ConvArgs args; | |||
| protected int rank => args.Rank; | |||
| protected int filters => args.Filters; | |||
| protected TensorShape kernel_size => args.KernelSize; | |||
| protected TensorShape strides => args.Strides; | |||
| protected string padding => args.Padding; | |||
| protected string data_format => args.DataFormat; | |||
| protected TensorShape dilation_rate => args.DilationRate; | |||
| protected Activation activation => args.Activation; | |||
| protected bool use_bias => args.UseBias; | |||
| protected IInitializer kernel_initializer => args.KernelInitializer; | |||
| protected IInitializer bias_initializer => args.BiasInitializer; | |||
| protected IVariableV1 kernel; | |||
| protected IVariableV1 bias; | |||
| protected Convolution _convolution_op; | |||
| string _tf_data_format; | |||
| public Conv(int rank, | |||
| int filters, | |||
| int[] kernel_size, | |||
| int[] strides = null, | |||
| string padding = "valid", | |||
| string data_format = null, | |||
| int[] dilation_rate = null, | |||
| IActivation activation = null, | |||
| bool use_bias = true, | |||
| IInitializer kernel_initializer = null, | |||
| IInitializer bias_initializer = null, | |||
| bool trainable = true, | |||
| string name = null) : base(trainable: trainable, name: name) | |||
| public Conv(ConvArgs args) : base(args) | |||
| { | |||
| this.rank = rank; | |||
| this.filters = filters; | |||
| this.kernel_size = kernel_size; | |||
| this.strides = strides; | |||
| this.padding = padding; | |||
| this.data_format = data_format; | |||
| this.dilation_rate = dilation_rate; | |||
| this.activation = activation; | |||
| this.use_bias = use_bias; | |||
| this.kernel_initializer = kernel_initializer; | |||
| this.bias_initializer = bias_initializer; | |||
| this.args = args; | |||
| args.KernelSize = conv_utils.normalize_tuple(args.KernelSize.dims, args.Rank, "kernel_size"); | |||
| args.Strides = conv_utils.normalize_tuple(args.Strides.dims, args.Rank, "strides"); | |||
| args.Padding = conv_utils.normalize_padding(args.Padding); | |||
| args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); | |||
| args.DilationRate = conv_utils.normalize_tuple(args.DilationRate.dims, args.Rank, "dilation_rate"); | |||
| inputSpec = new InputSpec(ndim: rank + 2); | |||
| _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); | |||
| } | |||
| protected override void build(TensorShape input_shape) | |||
| { | |||
| int channel_axis = data_format == "channels_first" ? 1 : -1; | |||
| int input_dim = channel_axis < 0 ? | |||
| int input_channel = channel_axis < 0 ? | |||
| input_shape.dims[input_shape.ndim + channel_axis] : | |||
| input_shape.dims[channel_axis]; | |||
| var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; | |||
| kernel = (RefVariable)add_weight(name: "kernel", | |||
| TensorShape kernel_shape = kernel_size.dims.concat(new int[] { input_channel / args.Groups, filters }); | |||
| kernel = add_weight(name: "kernel", | |||
| shape: kernel_shape, | |||
| initializer: kernel_initializer, | |||
| trainable: true, | |||
| dtype: DType); | |||
| if (use_bias) | |||
| bias = (RefVariable)add_weight(name: "bias", | |||
| bias = add_weight(name: "bias", | |||
| shape: new int[] { filters }, | |||
| initializer: bias_initializer, | |||
| trainable: true, | |||
| dtype: DType); | |||
| var axes = new Dictionary<int, int>(); | |||
| axes.Add(-1, input_dim); | |||
| axes.Add(-1, input_channel); | |||
| inputSpec = new InputSpec(ndim: rank + 2, axes: axes); | |||
| string op_padding; | |||
| string tf_padding; | |||
| if (padding == "causal") | |||
| op_padding = "valid"; | |||
| tf_padding = "VALID"; | |||
| else | |||
| op_padding = padding; | |||
| tf_padding = padding.ToUpper(); | |||
| var df = conv_utils.convert_data_format(data_format, rank + 2); | |||
| _convolution_op = nn_ops.Convolution(input_shape, | |||
| kernel.shape, | |||
| op_padding.ToUpper(), | |||
| tf_padding, | |||
| strides, | |||
| dilation_rate, | |||
| data_format: df); | |||
| data_format: _tf_data_format); | |||
| built = true; | |||
| } | |||
| @@ -119,12 +106,12 @@ namespace Tensorflow.Keras.Layers | |||
| } | |||
| else | |||
| { | |||
| outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC"); | |||
| outputs = nn_ops.bias_add(outputs, bias.AsTensor(), data_format: "NHWC"); | |||
| } | |||
| } | |||
| if (activation != null) | |||
| outputs = activation.Activate(outputs); | |||
| outputs = activation(outputs); | |||
| return outputs; | |||
| } | |||
| @@ -14,36 +14,15 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Operations.Activation; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public class Conv2D : Conv | |||
| { | |||
| public Conv2D(int filters, | |||
| int[] kernel_size, | |||
| int[] strides = null, | |||
| string padding = "valid", | |||
| string data_format = "channels_last", | |||
| int[] dilation_rate = null, | |||
| IActivation activation = null, | |||
| bool use_bias = true, | |||
| IInitializer kernel_initializer = null, | |||
| IInitializer bias_initializer = null, | |||
| bool trainable = true, | |||
| string name = null) : base(2, | |||
| filters, | |||
| kernel_size, | |||
| strides: strides, | |||
| padding: padding, | |||
| data_format: data_format, | |||
| dilation_rate: dilation_rate, | |||
| activation: activation, | |||
| use_bias: use_bias, | |||
| kernel_initializer: kernel_initializer, | |||
| bias_initializer: bias_initializer, | |||
| trainable: trainable, | |||
| name: name) | |||
| public Conv2D(Conv2DArgs args) | |||
| : base(args) | |||
| { | |||
| } | |||
| @@ -0,0 +1,41 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public class Dropout : Layer | |||
| { | |||
| DropoutArgs args; | |||
| public Dropout(DropoutArgs args) | |||
| : base(args) | |||
| { | |||
| this.args = args; | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| { | |||
| var output = tf_utils.smart_cond(is_training, | |||
| () => tf.nn.dropout(inputs, | |||
| noise_shape: get_noise_shape(inputs), | |||
| seed: args.Seed, | |||
| rate: args.Rate), | |||
| () => array_ops.identity(inputs)); | |||
| return output; | |||
| } | |||
| Tensor get_noise_shape(Tensor inputs) | |||
| { | |||
| if (args.NoiseShape == null) | |||
| return null; | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,21 +1,14 @@ | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public class MaxPooling2D : Pooling2D | |||
| { | |||
| public MaxPooling2D( | |||
| int[] pool_size, | |||
| int[] strides, | |||
| string padding = "valid", | |||
| string data_format = null, | |||
| string name = null) : base(tf.nn.max_pool_fn, pool_size, | |||
| strides, | |||
| padding: padding, | |||
| data_format: data_format, | |||
| name: name) | |||
| public MaxPooling2D(MaxPooling2DArgs args) | |||
| : base(args) | |||
| { | |||
| args.PoolFunction = tf.nn.max_pool_fn; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,57 +14,49 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| namespace Tensorflow.Keras.Layers | |||
| { | |||
| public class Pooling2D : Tensorflow.Layers.Layer | |||
| public class Pooling2D : Layer | |||
| { | |||
| private IPoolFunction pool_function; | |||
| private int[] pool_size; | |||
| private int[] strides; | |||
| private string padding; | |||
| private string data_format; | |||
| #pragma warning disable CS0108 // Member hides inherited member; missing new keyword | |||
| private InputSpec input_spec; | |||
| #pragma warning restore CS0108 // Member hides inherited member; missing new keyword | |||
| Pooling2DArgs args; | |||
| InputSpec input_spec; | |||
| public Pooling2D(IPoolFunction pool_function, | |||
| int[] pool_size, | |||
| int[] strides, | |||
| string padding = "valid", | |||
| string data_format = null, | |||
| string name = null) : base(name: name) | |||
| public Pooling2D(Pooling2DArgs args) | |||
| : base(args) | |||
| { | |||
| this.pool_function = pool_function; | |||
| this.pool_size = conv_utils.normalize_tuple(pool_size, 2, "pool_size"); | |||
| this.strides = conv_utils.normalize_tuple(strides, 2, "strides"); | |||
| this.padding = conv_utils.normalize_padding(padding); | |||
| this.data_format = conv_utils.normalize_data_format(data_format); | |||
| this.input_spec = new InputSpec(ndim: 4); | |||
| this.args = args; | |||
| args.PoolSize = conv_utils.normalize_tuple(args.PoolSize, 2, "pool_size"); | |||
| args.Strides = conv_utils.normalize_tuple(args.Strides, 2, "strides"); | |||
| args.Padding = conv_utils.normalize_padding(args.Padding); | |||
| args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); | |||
| input_spec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
| { | |||
| int[] pool_shape; | |||
| if (data_format == "channels_last") | |||
| int[] strides; | |||
| if (args.DataFormat == "channels_last") | |||
| { | |||
| pool_shape = new int[] { 1, pool_size[0], pool_size[1], 1 }; | |||
| strides = new int[] { 1, strides[0], strides[1], 1 }; | |||
| pool_shape = new int[] { 1, args.PoolSize.dims[0], args.PoolSize.dims[1], 1 }; | |||
| strides = new int[] { 1, args.Strides.dims[0], args.Strides.dims[1], 1 }; | |||
| } | |||
| else | |||
| { | |||
| pool_shape = new int[] { 1, 1, pool_size[0], pool_size[1] }; | |||
| strides = new int[] { 1, 1, strides[0], strides[1] }; | |||
| pool_shape = new int[] { 1, 1, args.PoolSize.dims[0], args.PoolSize.dims[1] }; | |||
| strides = new int[] { 1, 1, args.Strides.dims[0], args.Strides.dims[1] }; | |||
| } | |||
| var outputs = pool_function.Apply( | |||
| var outputs = args.PoolFunction.Apply( | |||
| inputs, | |||
| ksize: pool_shape, | |||
| strides: strides, | |||
| padding: padding.ToUpper(), | |||
| data_format: conv_utils.convert_data_format(data_format, 4)); | |||
| padding: args.Padding.ToUpper(), | |||
| data_format: conv_utils.convert_data_format(args.DataFormat, 4)); | |||
| return outputs; | |||
| } | |||
| @@ -14,6 +14,8 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Linq; | |||
| namespace Tensorflow.Keras.Utils | |||
| { | |||
| public class conv_utils | |||
| @@ -44,7 +46,10 @@ namespace Tensorflow.Keras.Utils | |||
| public static int[] normalize_tuple(int[] value, int n, string name) | |||
| { | |||
| return value; | |||
| if (value.Length == 1) | |||
| return Enumerable.Range(0, n).Select(x => value[0]).ToArray(); | |||
| else | |||
| return value; | |||
| } | |||
| public static string normalize_padding(string value) | |||
| @@ -54,6 +59,8 @@ namespace Tensorflow.Keras.Utils | |||
| public static string normalize_data_format(string value) | |||
| { | |||
| if (string.IsNullOrEmpty(value)) | |||
| return ImageDataFormat.channels_last.ToString(); | |||
| return value.ToLower(); | |||
| } | |||
| } | |||
| @@ -47,5 +47,16 @@ namespace Tensorflow.Keras.Utils | |||
| false_fn: false_fn, | |||
| name: name); | |||
| } | |||
| public static Tensor smart_cond(bool pred, | |||
| Func<Tensor> true_fn = null, | |||
| Func<Tensor> false_fn = null, | |||
| string name = null) | |||
| { | |||
| return smart_module.smart_cond(pred, | |||
| true_fn: true_fn, | |||
| false_fn: false_fn, | |||
| name: name); | |||
| } | |||
| } | |||
| } | |||
| @@ -76,7 +76,7 @@ namespace Tensorflow.Operations | |||
| name: name); | |||
| } | |||
| public Tensor __call__(Tensor inp, RefVariable filter) | |||
| public Tensor __call__(Tensor inp, IVariableV1 filter) | |||
| { | |||
| return conv_op.__call__(inp, filter); | |||
| } | |||
| @@ -67,12 +67,12 @@ namespace Tensorflow.Operations | |||
| } | |||
| } | |||
| public Tensor __call__(Tensor inp, RefVariable filter) | |||
| public Tensor __call__(Tensor inp, IVariableV1 filter) | |||
| { | |||
| return conv_op(new Conv2dParams | |||
| { | |||
| Input = inp, | |||
| Filter = filter, | |||
| Filter = filter.AsTensor(), | |||
| Strides = strides, | |||
| Padding = padding, | |||
| DataFormat = data_format, | |||
| @@ -68,7 +68,7 @@ namespace Tensorflow.Operations | |||
| } | |||
| } | |||
| public Tensor __call__(Tensor inp, RefVariable filter) | |||
| public Tensor __call__(Tensor inp, IVariableV1 filter) | |||
| { | |||
| return call.__call__(inp, filter); | |||
| } | |||
| @@ -68,7 +68,7 @@ namespace Tensorflow | |||
| /// <param name="seed"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public static Tensor dropout_v2(Tensor x, Tensor rate, Tensor noise_shape = null, int? seed = null, string name = null) | |||
| public static Tensor dropout_v2(Tensor x, float rate, Tensor noise_shape = null, int? seed = null, string name = null) | |||
| { | |||
| return tf_with(ops.name_scope(name, "dropout", x), scope => | |||
| { | |||
| @@ -78,11 +78,10 @@ namespace Tensorflow | |||
| throw new NotImplementedException($"x has to be a floating point tensor since it's going to" + | |||
| $" be scaled. Got a {x.dtype} tensor instead."); | |||
| rate = ops.convert_to_tensor(rate, dtype: x.dtype, name: "rate"); | |||
| // Do nothing if we know rate == 0 | |||
| var val = tensor_util.constant_value(rate); | |||
| if (!(val is null) && val.Data<float>()[0] == 0) | |||
| return x; | |||
| var keep_prob = 1 - rate; | |||
| var scale = 1 / keep_prob; | |||
| var scale_tensor = ops.convert_to_tensor(scale, dtype: x.dtype); | |||
| var ret = gen_math_ops.mul(x, scale_tensor); | |||
| noise_shape = _get_noise_shape(x, noise_shape); | |||
| @@ -92,13 +91,12 @@ namespace Tensorflow | |||
| // NOTE: Random uniform actually can only generate 2^23 floats on [1.0, 2.0) | |||
| // and subtract 1.0. | |||
| var random_tensor = random_ops.random_uniform(noise_shape, seed: seed, dtype: x.dtype); | |||
| var keep_prob = 1.0f - rate; | |||
| var scale = 1.0f / keep_prob; | |||
| // NOTE: if (1.0 + rate) - 1 is equal to rate, then we want to consider that | |||
| // float to be selected, hence we use a >= comparison. | |||
| var keep_mask = random_tensor >= rate; | |||
| var ret = x * scale * math_ops.cast(keep_mask, x.dtype); | |||
| ret.set_shape(x.TensorShape); | |||
| ret = x * scale * math_ops.cast(keep_mask, x.dtype); | |||
| if (!tf.executing_eagerly()) | |||
| ret.set_shape(x.TensorShape); | |||
| return ret; | |||
| }); | |||
| } | |||