From cdf39c52a3a281da36665758b494e903541b397e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 15 Aug 2020 06:10:10 -0500 Subject: [PATCH] tf.keras.layers. #570 --- src/TensorFlowNET.Core/APIs/tf.layers.cs | 46 +++++----- src/TensorFlowNET.Core/APIs/tf.nn.cs | 3 +- .../Framework/smart_module.cs | 8 ++ .../Keras/ArgsDefinition/Conv2DArgs.cs | 11 +++ .../Keras/ArgsDefinition/ConvArgs.cs | 32 +++++++ .../Keras/ArgsDefinition/DropoutArgs.cs | 27 ++++++ .../Keras/ArgsDefinition/FlattenArgs.cs | 11 +++ .../Keras/ArgsDefinition/MaxPooling2D.cs | 12 +++ .../Keras/ArgsDefinition/Pooling2DArgs.cs | 34 +++++++ .../Keras/Engine/Flatten.cs | 39 ++++++++ .../Keras/Engine/Layer.Layers.cs | 75 ++++++++++++++++ src/TensorFlowNET.Core/Keras/Layers/Conv.cs | 89 ++++++++----------- src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs | 27 +----- .../Keras/Layers/Dropout.cs | 41 +++++++++ .../Keras/Layers/MaxPooling2D.cs | 17 ++-- .../Keras/Layers/Pooling2D.cs | 50 +++++------ .../Keras/Utils/conv_utils.cs | 9 +- .../Keras/Utils/tf_utils.cs | 11 +++ .../Operations/NnOps/Convolution.cs | 2 +- .../Operations/NnOps/_NonAtrousConvolution.cs | 4 +- .../Operations/NnOps/_WithSpaceToBatch.cs | 2 +- src/TensorFlowNET.Core/Operations/nn_ops.cs | 18 ++-- 22 files changed, 415 insertions(+), 153 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/DropoutArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/FlattenArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/MaxPooling2D.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling2DArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Flatten.cs create mode 100644 src/TensorFlowNET.Core/Keras/Layers/Dropout.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 085df3c5..46772dc9 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -40,7 +40,7 @@ namespace Tensorflow string data_format= "channels_last", int[] dilation_rate = null, bool use_bias = true, - IActivation activation = null, + Activation activation = null, IInitializer kernel_initializer = null, IInitializer bias_initializer = null, bool trainable = true, @@ -53,20 +53,23 @@ namespace Tensorflow if (bias_initializer == null) bias_initializer = tf.zeros_initializer; - var layer = new Conv2D(filters, - kernel_size: kernel_size, - strides: strides, - padding: padding, - data_format: data_format, - dilation_rate: dilation_rate, - activation: activation, - use_bias: use_bias, - kernel_initializer: kernel_initializer, - bias_initializer: bias_initializer, - trainable: trainable, - name: name); + var layer = new Conv2D(new Conv2DArgs + { + Filters = filters, + KernelSize = kernel_size, + Strides = strides, + Padding = padding, + DataFormat = data_format, + DilationRate = dilation_rate, + Activation = activation, + UseBias = use_bias, + KernelInitializer = kernel_initializer, + BiasInitializer = bias_initializer, + Trainable = trainable, + Name = name + }); - return layer.apply(inputs).Item1; + return layer.Apply(inputs); } /// @@ -140,13 +143,16 @@ namespace Tensorflow string data_format = "channels_last", string name = null) { - var layer = new MaxPooling2D(pool_size: pool_size, - strides: strides, - padding: padding, - data_format: data_format, - name: name); + var layer = new 
MaxPooling2D(new MaxPooling2DArgs + { + PoolSize = pool_size, + Strides = strides, + Padding = padding, + DataFormat = data_format, + Name = name + }); - return layer.apply(inputs).Item1; + return layer.Apply(inputs); } /// diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 3d470ea3..3ab2a5b4 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -66,9 +66,8 @@ namespace Tensorflow Tensor keep = null; if (keep_prob != null) keep = 1.0f - keep_prob; - var rate_tensor = keep; - return nn_ops.dropout_v2(x, rate: rate_tensor, noise_shape: noise_shape, seed: seed, name: name); + return nn_ops.dropout_v2(x, rate: rate.Value, noise_shape: noise_shape, seed: seed, name: name); } /// diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs index 0f1cb76e..3b837af2 100644 --- a/src/TensorFlowNET.Core/Framework/smart_module.cs +++ b/src/TensorFlowNET.Core/Framework/smart_module.cs @@ -41,6 +41,14 @@ namespace Tensorflow.Framework name: name); } + public static Tensor smart_cond(bool pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + { + return pred ? true_fn() : false_fn(); + } + public static bool? smart_constant_value(Tensor pred) { var pred_value = tensor_util.constant_value(pred); diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs new file mode 100644 index 00000000..be0ef74e --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class Conv2DArgs : ConvArgs + { + + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs new file mode 100644 index 00000000..b96a6ba7 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class ConvArgs : LayerArgs + { + public int Rank { get; set; } = 2; + public int Filters { get; set; } + public TensorShape KernelSize { get; set; } = 5; + + /// + /// specifying the stride length of the convolution. 
+ /// + public TensorShape Strides { get; set; } = (1, 1); + + public string Padding { get; set; } = "valid"; + public string DataFormat { get; set; } + public TensorShape DilationRate { get; set; } = (1, 1); + public int Groups { get; set; } = 1; + public Activation Activation { get; set; } + public bool UseBias { get; set; } + public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; + public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; + public IInitializer KernelRegularizer { get; set; } + public IInitializer BiasRegularizer { get; set; } + public Action KernelConstraint { get; set; } + public Action BiasConstraint { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DropoutArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DropoutArgs.cs new file mode 100644 index 00000000..4317edf4 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DropoutArgs.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class DropoutArgs : LayerArgs + { + /// + /// Float between 0 and 1. Fraction of the input units to drop. + /// + public float Rate { get; set; } + + /// + /// 1D integer tensor representing the shape of the + /// binary dropout mask that will be multiplied with the input. + /// + public TensorShape NoiseShape { get; set; } + + /// + /// random seed. + /// + public int? Seed { get; set; } + + public bool SupportsMasking { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/FlattenArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/FlattenArgs.cs new file mode 100644 index 00000000..3f31d532 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/FlattenArgs.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class FlattenArgs : LayerArgs + { + public string DataFormat { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/MaxPooling2D.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/MaxPooling2D.cs new file mode 100644 index 00000000..c8c86b9a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/MaxPooling2D.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Layers; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class MaxPooling2DArgs : Pooling2DArgs + { + + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling2DArgs.cs new file mode 100644 index 00000000..3ff9092c --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling2DArgs.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class Pooling2DArgs : LayerArgs + { + /// + /// The pooling function to apply, e.g. `tf.nn.max_pool2d`. + /// + public IPoolFunction PoolFunction { get; set; } + + /// + /// specifying the size of the pooling window. + /// + public TensorShape PoolSize { get; set; } + + /// + /// specifying the strides of the pooling operation. + /// + public TensorShape Strides { get; set; } + + /// + /// The padding method, either 'valid' or 'same'. + /// + public string Padding { get; set; } = "valid"; + + /// + /// one of `channels_last` (default) or `channels_first`. 
+ /// + public string DataFormat { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs new file mode 100644 index 00000000..45cfd8f2 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Flatten.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public class Flatten : Layer + { + FlattenArgs args; + InputSpec input_spec; + bool _channels_first; + + public Flatten(FlattenArgs args) + : base(args) + { + args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); + input_spec = new InputSpec(min_ndim: 1); + _channels_first = args.DataFormat == "channels_first"; + } + + protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + { + if (_channels_first) + { + throw new NotImplementedException(""); + } + + if (tf.executing_eagerly()) + { + return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 }); + } + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs index 14f3f79a..14be45b0 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs @@ -25,5 +25,80 @@ namespace Tensorflow.Keras.Engine _layers.Add(layer); return layer; } + + protected Layer Conv2D(int filters, + int kernel_size, + TensorShape strides = null, + string padding = "valid", + string data_format = null, + TensorShape dilation_rate = null, + int groups = 1, + Activation activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + bool trainable = true, + string name = null) + { + var layer = new Conv2D(new Conv2DArgs + { + Filters = filters, + KernelSize = kernel_size, + Strides = strides ?? (1, 1), + Padding = padding, + DataFormat = data_format, + DilationRate = dilation_rate ?? (1, 1), + Groups = groups, + Activation = activation, + UseBias = use_bias, + KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, + BiasInitializer = bias_initializer ?? tf.zeros_initializer, + Trainable = trainable, + Name = name + }); + + _layers.Add(layer); + return layer; + } + + protected Layer MaxPooling2D(TensorShape pool_size, + TensorShape strides, + string padding = "valid", + string data_format = null, + string name = null) + { + var layer = new MaxPooling2D(new MaxPooling2DArgs + { + PoolSize = pool_size, + Strides = strides, + Padding = padding, + DataFormat = data_format, + Name = name + }); + + _layers.Add(layer); + return layer; + } + + protected Layer Dropout(float rate, TensorShape noise_shape = null, int? 
seed = null) + { + var layer = new Dropout(new DropoutArgs + { + Rate = rate, + NoiseShape = noise_shape, + Seed = seed + }); + + _layers.Add(layer); + return layer; + } + + protected Layer Flatten() + { + var layer = new Flatten(new FlattenArgs()); + + _layers.Add(layer); + return layer; + } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs index fa3b7505..c85c4379 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv.cs @@ -16,6 +16,7 @@ using System; using System.Collections.Generic; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; using Tensorflow.Operations; @@ -23,87 +24,73 @@ using Tensorflow.Operations.Activation; namespace Tensorflow.Keras.Layers { - public class Conv : Tensorflow.Layers.Layer + public class Conv : Layer { - protected int rank; - protected int filters; - protected int[] kernel_size; - protected int[] strides; - protected string padding; - protected string data_format; - protected int[] dilation_rate; - protected IActivation activation; - protected bool use_bias; - protected IInitializer kernel_initializer; - protected IInitializer bias_initializer; - protected RefVariable kernel; - protected RefVariable bias; + ConvArgs args; + protected int rank => args.Rank; + protected int filters => args.Filters; + protected TensorShape kernel_size => args.KernelSize; + protected TensorShape strides => args.Strides; + protected string padding => args.Padding; + protected string data_format => args.DataFormat; + protected TensorShape dilation_rate => args.DilationRate; + protected Activation activation => args.Activation; + protected bool use_bias => args.UseBias; + protected IInitializer kernel_initializer => args.KernelInitializer; + protected IInitializer bias_initializer => args.BiasInitializer; + protected IVariableV1 kernel; + protected IVariableV1 bias; protected Convolution _convolution_op; + string _tf_data_format; - public Conv(int rank, - int filters, - int[] kernel_size, - int[] strides = null, - string padding = "valid", - string data_format = null, - int[] dilation_rate = null, - IActivation activation = null, - bool use_bias = true, - IInitializer kernel_initializer = null, - IInitializer bias_initializer = null, - bool trainable = true, - string name = null) : base(trainable: trainable, name: name) + public Conv(ConvArgs args) : base(args) { - this.rank = rank; - this.filters = filters; - this.kernel_size = kernel_size; - this.strides = strides; - this.padding = padding; - this.data_format = data_format; - this.dilation_rate = dilation_rate; - this.activation = activation; - this.use_bias = use_bias; - this.kernel_initializer = kernel_initializer; - this.bias_initializer = bias_initializer; + this.args = args; + args.KernelSize = conv_utils.normalize_tuple(args.KernelSize.dims, args.Rank, "kernel_size"); + args.Strides = conv_utils.normalize_tuple(args.Strides.dims, args.Rank, "strides"); + args.Padding = conv_utils.normalize_padding(args.Padding); + args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); + args.DilationRate = conv_utils.normalize_tuple(args.DilationRate.dims, args.Rank, "dilation_rate"); inputSpec = new InputSpec(ndim: rank + 2); + _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); } protected override void build(TensorShape input_shape) { int channel_axis = data_format == "channels_first" ? 1 : -1; - int input_dim = channel_axis < 0 ? 
+ int input_channel = channel_axis < 0 ? input_shape.dims[input_shape.ndim + channel_axis] : input_shape.dims[channel_axis]; - var kernel_shape = new int[] { kernel_size[0], kernel_size[1], input_dim, filters }; - kernel = (RefVariable)add_weight(name: "kernel", + TensorShape kernel_shape = kernel_size.dims.concat(new int[] { input_channel / args.Groups, filters }); + kernel = add_weight(name: "kernel", shape: kernel_shape, initializer: kernel_initializer, trainable: true, dtype: DType); if (use_bias) - bias = (RefVariable)add_weight(name: "bias", + bias = add_weight(name: "bias", shape: new int[] { filters }, initializer: bias_initializer, trainable: true, dtype: DType); var axes = new Dictionary(); - axes.Add(-1, input_dim); + axes.Add(-1, input_channel); inputSpec = new InputSpec(ndim: rank + 2, axes: axes); - string op_padding; + string tf_padding; if (padding == "causal") - op_padding = "valid"; + tf_padding = "VALID"; else - op_padding = padding; + tf_padding = padding.ToUpper(); - var df = conv_utils.convert_data_format(data_format, rank + 2); + _convolution_op = nn_ops.Convolution(input_shape, kernel.shape, - op_padding.ToUpper(), + tf_padding, strides, dilation_rate, - data_format: df); + data_format: _tf_data_format); built = true; } @@ -119,12 +106,12 @@ namespace Tensorflow.Keras.Layers } else { - outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC"); + outputs = nn_ops.bias_add(outputs, bias.AsTensor(), data_format: "NHWC"); } } if (activation != null) - outputs = activation.Activate(outputs); + outputs = activation(outputs); return outputs; } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs index 8bc83cce..9fe38ad2 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs @@ -14,36 +14,15 @@ limitations under the License. 
******************************************************************************/ +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Operations.Activation; namespace Tensorflow.Keras.Layers { public class Conv2D : Conv { - public Conv2D(int filters, - int[] kernel_size, - int[] strides = null, - string padding = "valid", - string data_format = "channels_last", - int[] dilation_rate = null, - IActivation activation = null, - bool use_bias = true, - IInitializer kernel_initializer = null, - IInitializer bias_initializer = null, - bool trainable = true, - string name = null) : base(2, - filters, - kernel_size, - strides: strides, - padding: padding, - data_format: data_format, - dilation_rate: dilation_rate, - activation: activation, - use_bias: use_bias, - kernel_initializer: kernel_initializer, - bias_initializer: bias_initializer, - trainable: trainable, - name: name) + public Conv2D(Conv2DArgs args) + : base(args) { } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs new file mode 100644 index 00000000..6449be48 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/Dropout.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + public class Dropout : Layer + { + DropoutArgs args; + + public Dropout(DropoutArgs args) + : base(args) + { + this.args = args; + } + + protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) + { + var output = tf_utils.smart_cond(is_training, + () => tf.nn.dropout(inputs, + noise_shape: get_noise_shape(inputs), + seed: args.Seed, + rate: args.Rate), + () => array_ops.identity(inputs)); + + return output; + } + + Tensor get_noise_shape(Tensor inputs) + { + if (args.NoiseShape == null) + return null; + + return null; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs index 27234078..1c3c3920 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs @@ -1,21 +1,14 @@ -using static Tensorflow.Binding; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { public class MaxPooling2D : Pooling2D { - public MaxPooling2D( - int[] pool_size, - int[] strides, - string padding = "valid", - string data_format = null, - string name = null) : base(tf.nn.max_pool_fn, pool_size, - strides, - padding: padding, - data_format: data_format, - name: name) + public MaxPooling2D(MaxPooling2DArgs args) + : base(args) { - + args.PoolFunction = tf.nn.max_pool_fn; } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs index 26f30885..030578d6 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs @@ -14,57 +14,49 @@ limitations under the License. 
******************************************************************************/ +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; namespace Tensorflow.Keras.Layers { - public class Pooling2D : Tensorflow.Layers.Layer + public class Pooling2D : Layer { - private IPoolFunction pool_function; - private int[] pool_size; - private int[] strides; - private string padding; - private string data_format; -#pragma warning disable CS0108 // Member hides inherited member; missing new keyword - private InputSpec input_spec; -#pragma warning restore CS0108 // Member hides inherited member; missing new keyword + Pooling2DArgs args; + InputSpec input_spec; - public Pooling2D(IPoolFunction pool_function, - int[] pool_size, - int[] strides, - string padding = "valid", - string data_format = null, - string name = null) : base(name: name) + public Pooling2D(Pooling2DArgs args) + : base(args) { - this.pool_function = pool_function; - this.pool_size = conv_utils.normalize_tuple(pool_size, 2, "pool_size"); - this.strides = conv_utils.normalize_tuple(strides, 2, "strides"); - this.padding = conv_utils.normalize_padding(padding); - this.data_format = conv_utils.normalize_data_format(data_format); - this.input_spec = new InputSpec(ndim: 4); + this.args = args; + args.PoolSize = conv_utils.normalize_tuple(args.PoolSize, 2, "pool_size"); + args.Strides = conv_utils.normalize_tuple(args.Strides, 2, "strides"); + args.Padding = conv_utils.normalize_padding(args.Padding); + args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); + input_spec = new InputSpec(ndim: 4); } protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) { int[] pool_shape; - if (data_format == "channels_last") + int[] strides; + if (args.DataFormat == "channels_last") { - pool_shape = new int[] { 1, pool_size[0], pool_size[1], 1 }; - strides = new int[] { 1, strides[0], strides[1], 1 }; + pool_shape = new int[] { 1, args.PoolSize.dims[0], args.PoolSize.dims[1], 1 }; + strides = new int[] { 1, args.Strides.dims[0], args.Strides.dims[1], 1 }; } else { - pool_shape = new int[] { 1, 1, pool_size[0], pool_size[1] }; - strides = new int[] { 1, 1, strides[0], strides[1] }; + pool_shape = new int[] { 1, 1, args.PoolSize.dims[0], args.PoolSize.dims[1] }; + strides = new int[] { 1, 1, args.Strides.dims[0], args.Strides.dims[1] }; } - var outputs = pool_function.Apply( + var outputs = args.PoolFunction.Apply( inputs, ksize: pool_shape, strides: strides, - padding: padding.ToUpper(), - data_format: conv_utils.convert_data_format(data_format, 4)); + padding: args.Padding.ToUpper(), + data_format: conv_utils.convert_data_format(args.DataFormat, 4)); return outputs; } diff --git a/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs index ba27fb3d..8d799468 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/conv_utils.cs @@ -14,6 +14,8 @@ limitations under the License. 
******************************************************************************/ +using System.Linq; + namespace Tensorflow.Keras.Utils { public class conv_utils @@ -44,7 +46,10 @@ namespace Tensorflow.Keras.Utils public static int[] normalize_tuple(int[] value, int n, string name) { - return value; + if (value.Length == 1) + return Enumerable.Range(0, n).Select(x => value[0]).ToArray(); + else + return value; } public static string normalize_padding(string value) @@ -54,6 +59,8 @@ namespace Tensorflow.Keras.Utils public static string normalize_data_format(string value) { + if (string.IsNullOrEmpty(value)) + return ImageDataFormat.channels_last.ToString(); return value.ToLower(); } } diff --git a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs index fc1b80ff..01098e62 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs @@ -47,5 +47,16 @@ namespace Tensorflow.Keras.Utils false_fn: false_fn, name: name); } + + public static Tensor smart_cond(bool pred, + Func true_fn = null, + Func false_fn = null, + string name = null) + { + return smart_module.smart_cond(pred, + true_fn: true_fn, + false_fn: false_fn, + name: name); + } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs b/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs index a6f419dd..be4aca3c 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs @@ -76,7 +76,7 @@ namespace Tensorflow.Operations name: name); } - public Tensor __call__(Tensor inp, RefVariable filter) + public Tensor __call__(Tensor inp, IVariableV1 filter) { return conv_op.__call__(inp, filter); } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs b/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs index c5bcb2cf..f947cdbc 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs @@ -67,12 +67,12 @@ namespace Tensorflow.Operations } } - public Tensor __call__(Tensor inp, RefVariable filter) + public Tensor __call__(Tensor inp, IVariableV1 filter) { return conv_op(new Conv2dParams { Input = inp, - Filter = filter, + Filter = filter.AsTensor(), Strides = strides, Padding = padding, DataFormat = data_format, diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs index e9b6126c..8ae4ee36 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs @@ -68,7 +68,7 @@ namespace Tensorflow.Operations } } - public Tensor __call__(Tensor inp, RefVariable filter) + public Tensor __call__(Tensor inp, IVariableV1 filter) { return call.__call__(inp, filter); } diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index 124fd72b..ce3875cc 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -68,7 +68,7 @@ namespace Tensorflow /// /// /// - public static Tensor dropout_v2(Tensor x, Tensor rate, Tensor noise_shape = null, int? seed = null, string name = null) + public static Tensor dropout_v2(Tensor x, float rate, Tensor noise_shape = null, int? 
seed = null, string name = null) { return tf_with(ops.name_scope(name, "dropout", x), scope => { @@ -78,11 +78,10 @@ namespace Tensorflow throw new NotImplementedException($"x has to be a floating point tensor since it's going to" + $" be scaled. Got a {x.dtype} tensor instead."); - rate = ops.convert_to_tensor(rate, dtype: x.dtype, name: "rate"); - // Do nothing if we know rate == 0 - var val = tensor_util.constant_value(rate); - if (!(val is null) && val.Data()[0] == 0) - return x; + var keep_prob = 1 - rate; + var scale = 1 / keep_prob; + var scale_tensor = ops.convert_to_tensor(scale, dtype: x.dtype); + var ret = gen_math_ops.mul(x, scale_tensor); noise_shape = _get_noise_shape(x, noise_shape); @@ -92,13 +91,12 @@ namespace Tensorflow // NOTE: Random uniform actually can only generate 2^23 floats on [1.0, 2.0) // and subtract 1.0. var random_tensor = random_ops.random_uniform(noise_shape, seed: seed, dtype: x.dtype); - var keep_prob = 1.0f - rate; - var scale = 1.0f / keep_prob; // NOTE: if (1.0 + rate) - 1 is equal to rate, then we want to consider that // float to be selected, hence we use a >= comparison. var keep_mask = random_tensor >= rate; - var ret = x * scale * math_ops.cast(keep_mask, x.dtype); - ret.set_shape(x.TensorShape); + ret = x * scale * math_ops.cast(keep_mask, x.dtype); + if (!tf.executing_eagerly()) + ret.set_shape(x.TensorShape); return ret; }); }
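
---

Note for reviewers: to make the intent of the new args-based layer API easier to see, below is a minimal usage sketch that only exercises types added in this patch (`Conv2DArgs`, `MaxPooling2DArgs`, `DropoutArgs`, `FlattenArgs` and their layers). The class name `KerasLayersSketch`, the `BuildHead` method, the `images` parameter and the concrete hyperparameters are illustrative assumptions, not part of the change; `Apply` is assumed to take the input tensor and return the output tensor, which is how the reworked `tf.layers.conv2d` / `tf.layers.max_pooling2d` wrappers above call it.

```csharp
using Tensorflow;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;   // Flatten
using Tensorflow.Keras.Layers;   // Conv2D, MaxPooling2D, Dropout

public static class KerasLayersSketch
{
    // Chains the layers introduced by this patch the same way the reworked
    // tf.layers.* wrappers do: build an *Args object, construct the layer,
    // then call Apply on the input tensor.
    public static Tensor BuildHead(Tensor images)
    {
        var conv = new Conv2D(new Conv2DArgs
        {
            Filters = 32,
            KernelSize = 5,        // a single value is expanded to (5, 5) by normalize_tuple
            Strides = (1, 1),
            Padding = "same"       // normalize_padding lower-cases it; build() upper-cases it for the op
        });

        var pool = new MaxPooling2D(new MaxPooling2DArgs
        {
            PoolSize = (2, 2),
            Strides = (2, 2)
        });

        var flatten = new Flatten(new FlattenArgs());           // channels_last + eager path only for now
        var dropout = new Dropout(new DropoutArgs { Rate = 0.5f });

        var x = conv.Apply(images);
        x = pool.Apply(x);
        x = flatten.Apply(x);
        // Dropout.call reduces to identity unless the layer is invoked with
        // is_training = true (see the smart_cond branch above).
        x = dropout.Apply(x);
        return x;
    }
}
```

This mirrors what the new protected helpers in `Layer.Layers.cs` do internally, so a subclassed layer can build the same stack through `Conv2D(...)`, `MaxPooling2D(...)`, `Flatten()` and `Dropout(...)` and have the children tracked in `_layers`.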