Keras Conv1Dtags/v0.40-tf2.4-tstring
| @@ -0,0 +1,7 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class Conv1DArgs : ConvolutionalArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,81 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| namespace Tensorflow.Operations | |||||
| { | |||||
| public class Conv1dParams | |||||
| { | |||||
| public string Name { get; set; } | |||||
| /// <summary> | |||||
| /// An optional `string` from: `"NHWC", "NCHW"`. Defaults to `"NHWC"`. | |||||
| /// Specify the data format of the input and output data. With the | |||||
| /// default format "NHWC", the data is stored in the order of: | |||||
| /// [batch, height, width, channels]. | |||||
| /// </summary> | |||||
| public string DataFormat { get; set; } = "NHWC"; | |||||
| /// <summary> | |||||
| /// Must be one of the following types: `half`, `bfloat16`, `float32`, `float64`. | |||||
| /// A 4-D tensor. The dimension order is interpreted according to the value | |||||
| /// </summary> | |||||
| public Tensor Input { get; set; } | |||||
| /// <summary> | |||||
| /// An integer vector representing the shape of `input` | |||||
| /// </summary> | |||||
| public Tensor InputSizes { get; set; } | |||||
| /// <summary> | |||||
| /// A 4-D tensor of shape | |||||
| /// </summary> | |||||
| public IVariableV1 Filter { get; set; } | |||||
| /// <summary> | |||||
| /// An integer vector representing the tensor shape of `filter` | |||||
| /// </summary> | |||||
| public Tensor FilterSizes { get; set; } | |||||
| /// <summary> | |||||
| /// A `Tensor`. Must have the same type as `filter`. | |||||
| /// 4-D with shape `[batch, out_height, out_width, out_channels]`. | |||||
| /// </summary> | |||||
| public Tensor OutBackProp { get; set; } | |||||
| /// <summary> | |||||
| /// The stride of the sliding window for each | |||||
| /// dimension of `input`. The dimension order is determined by the value of | |||||
| /// `data_format`, see below for details. | |||||
| /// </summary> | |||||
| public int[] Strides { get; set; } | |||||
| /// <summary> | |||||
| /// A `string` from: `"SAME", "VALID", "EXPLICIT"`. | |||||
| /// </summary> | |||||
| public string Padding { get; set; } | |||||
| public int[] ExplicitPaddings { get; set; } = new int[0]; | |||||
| public bool UseCudnnOnGpu { get; set; } = true; | |||||
| public int[] Dilations { get; set; } = new int[] { 1, 1, 1 }; | |||||
| public Conv1dParams() | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -41,8 +41,15 @@ namespace Tensorflow.Operations | |||||
| var filters_rank = filters.shape.rank; | var filters_rank = filters.shape.rank; | ||||
| var inputs_rank = input.shape.rank; | var inputs_rank = input.shape.rank; | ||||
| var num_spatial_dims = args.NumSpatialDims; | var num_spatial_dims = args.NumSpatialDims; | ||||
| if (num_spatial_dims == Unknown) | |||||
| if (args.Rank == 1) | |||||
| { | |||||
| // Special case: Conv1D | |||||
| num_spatial_dims = 1; | |||||
| } | |||||
| else if (num_spatial_dims == Unknown) | |||||
| { | |||||
| num_spatial_dims = filters_rank - 2; | num_spatial_dims = filters_rank - 2; | ||||
| } | |||||
| // Channel dimension. | // Channel dimension. | ||||
| var num_batch_dims = inputs_rank - num_spatial_dims - 1; | var num_batch_dims = inputs_rank - num_spatial_dims - 1; | ||||
| @@ -50,16 +57,16 @@ namespace Tensorflow.Operations | |||||
| throw new ValueError($"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one " + | throw new ValueError($"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one " + | ||||
| $"of 1, 2 or 3 but saw {num_spatial_dims}. num_batch_dims: {num_batch_dims}."); | $"of 1, 2 or 3 but saw {num_spatial_dims}. num_batch_dims: {num_batch_dims}."); | ||||
| var channel_index = num_batch_dims + num_spatial_dims; | |||||
| var dilations = _get_sequence(args.DilationRate, num_spatial_dims, channel_index); | |||||
| var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index); | |||||
| Tensor result = null; | Tensor result = null; | ||||
| tf_with(ops.name_scope(name, default_name: null), scope => | tf_with(ops.name_scope(name, default_name: null), scope => | ||||
| { | { | ||||
| name = scope; | name = scope; | ||||
| if (num_spatial_dims == 2) | if (num_spatial_dims == 2) | ||||
| { | { | ||||
| var channel_index = num_batch_dims + num_spatial_dims; | |||||
| var dilations = _get_sequence(args.DilationRate, num_spatial_dims, channel_index).ToArray(); | |||||
| var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index).ToArray(); | |||||
| result = gen_nn_ops.conv2d(new Conv2dParams | result = gen_nn_ops.conv2d(new Conv2dParams | ||||
| { | { | ||||
| Input = input, | Input = input, | ||||
| @@ -72,13 +79,37 @@ namespace Tensorflow.Operations | |||||
| }); | }); | ||||
| } | } | ||||
| else | else | ||||
| throw new NotImplementedException(""); | |||||
| { | |||||
| var channel_first = data_format == "NCW"; | |||||
| var spatial_start_dim = channel_first ? -2 : -3; | |||||
| var channel_index = channel_first ? 1 : 2; | |||||
| var dilations = _get_sequence(args.DilationRate, 1, channel_index); | |||||
| var strides = _get_sequence(args.Strides, 1, channel_index); | |||||
| strides.Insert(0, 1); | |||||
| dilations.Insert(0, 1); | |||||
| var expanded = tf.expand_dims(input, spatial_start_dim); | |||||
| result = gen_nn_ops.conv2d(new Conv2dParams | |||||
| { | |||||
| Input = expanded, | |||||
| Filter = filters, | |||||
| Strides = strides.ToArray(), | |||||
| Padding = padding, | |||||
| DataFormat = channel_first ? "NCHW" : "NHWC", | |||||
| Dilations = dilations.ToArray(), | |||||
| Name = name | |||||
| }); | |||||
| result = tf.squeeze(result, squeeze_dims: spatial_start_dim); | |||||
| } | |||||
| }); | }); | ||||
| return result; | return result; | ||||
| } | } | ||||
| int[] _get_sequence(int[] value, int n, int channel_index) | |||||
| IList<int> _get_sequence(int[] value, int n, int channel_index) | |||||
| { | { | ||||
| var seq = new List<int>(); | var seq = new List<int>(); | ||||
| @@ -95,7 +126,7 @@ namespace Tensorflow.Operations | |||||
| seq.Add(1); | seq.Add(1); | ||||
| } | } | ||||
| return seq.ToArray(); | |||||
| return seq; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -27,9 +27,11 @@ namespace Tensorflow | |||||
| public static ConvolutionInternal convolution_internal(string padding, | public static ConvolutionInternal convolution_internal(string padding, | ||||
| int[] strides, | int[] strides, | ||||
| int[] dilation_rate, | int[] dilation_rate, | ||||
| int rank, | |||||
| string name = null, | string name = null, | ||||
| string data_format = null) => new ConvolutionInternal(new ConvolutionalArgs | string data_format = null) => new ConvolutionInternal(new ConvolutionalArgs | ||||
| { | { | ||||
| Rank = rank, | |||||
| Padding = padding, | Padding = padding, | ||||
| Strides = strides, | Strides = strides, | ||||
| DilationRate = dilation_rate, | DilationRate = dilation_rate, | ||||
| @@ -0,0 +1,28 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public class Conv1D : Convolutional | |||||
| { | |||||
| public Conv1D(Conv1DArgs args) : base(args) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -93,6 +93,7 @@ namespace Tensorflow.Keras.Layers | |||||
| _convolution_op = nn_ops.convolution_internal(tf_padding, | _convolution_op = nn_ops.convolution_internal(tf_padding, | ||||
| strides, | strides, | ||||
| dilation_rate, | dilation_rate, | ||||
| rank, | |||||
| data_format: _tf_data_format, | data_format: _tf_data_format, | ||||
| name: tf_op_name); | name: tf_op_name); | ||||
| @@ -67,6 +67,113 @@ namespace Tensorflow.Keras.Layers | |||||
| Name = name | Name = name | ||||
| }); | }); | ||||
| /// <summary> | |||||
| /// 1D convolution layer (e.g. temporal convolution). | |||||
| /// This layer creates a convolution kernel that is convolved with the layer input over a single spatial(or temporal) dimension to produce a tensor of outputs.If use_bias is True, a bias vector is created and added to the outputs.Finally, if activation is not None, it is applied to the outputs as well. | |||||
| /// </summary> | |||||
| /// <param name="filters">Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution)</param> | |||||
| /// <param name="kernel_size">An integer specifying the width of the 1D convolution window.</param> | |||||
| /// <param name="strides">An integer specifying the stride of the convolution window . Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1.</param> | |||||
| /// <param name="padding">one of "valid" or "same" (case-insensitive). "valid" means no padding. "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input.</param> | |||||
| /// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. If you never set it, then it will be channels_last.</param> | |||||
| /// <param name="dilation_rate">An integer specifying the dilation rate to use for dilated convolution.Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.</param> | |||||
| /// <param name="groups">A positive integer specifying the number of groups in which the input is split along the channel axis. Each group is convolved separately with filters / groups filters. The output is the concatenation of all the groups results along the channel axis. Input channels and filters must both be divisible by groups.</param> | |||||
| /// <param name="activation">Activation function to use. If you don't specify anything, no activation is applied (see keras.activations).</param> | |||||
| /// <param name="use_bias">Boolean, whether the layer uses a bias vector.</param> | |||||
| /// <param name="kernel_initializer">Initializer for the kernel weights matrix (see keras.initializers).</param> | |||||
| /// <param name="bias_initializer">Initializer for the bias vector (see keras.initializers).</param> | |||||
| /// <param name="kernel_regularizer">Regularizer function applied to the kernel weights matrix (see keras.regularizers).</param> | |||||
| /// <param name="bias_regularizer">Regularizer function applied to the bias vector (see keras.regularizers).</param> | |||||
| /// <param name="activity_regularizer">Regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param> | |||||
| /// <returns>A tensor of rank 3 representing activation(conv1d(inputs, kernel) + bias).</returns> | |||||
| public Conv1D Conv1D(int filters, | |||||
| int? kernel_size = null, | |||||
| int? strides = null, | |||||
| string padding = "valid", | |||||
| string data_format = null, | |||||
| int? dilation_rate = null, | |||||
| int groups = 1, | |||||
| Activation activation = null, | |||||
| bool use_bias = true, | |||||
| IInitializer kernel_initializer = null, | |||||
| IInitializer bias_initializer = null, | |||||
| IRegularizer kernel_regularizer = null, | |||||
| IRegularizer bias_regularizer = null, | |||||
| IRegularizer activity_regularizer = null) | |||||
| { | |||||
| // Special case: Conv1D will be implemented as Conv2D with H=1, so we need to add a 1-sized dimension to the kernel. | |||||
| // Lower-level logic handles the stride and dilation_rate, but the kernel_size needs to be set properly here. | |||||
| var kernel = (kernel_size == null) ? (1, 5) : (1, kernel_size.Value); | |||||
| return new Conv1D(new Conv1DArgs | |||||
| { | |||||
| Rank = 1, | |||||
| Filters = filters, | |||||
| KernelSize = kernel, | |||||
| Strides = strides == null ? 1 : strides, | |||||
| Padding = padding, | |||||
| DataFormat = data_format, | |||||
| DilationRate = dilation_rate == null ? 1 : dilation_rate, | |||||
| Groups = groups, | |||||
| UseBias = use_bias, | |||||
| KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer, | |||||
| BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, | |||||
| KernelRegularizer = kernel_regularizer, | |||||
| BiasRegularizer = bias_regularizer, | |||||
| ActivityRegularizer = activity_regularizer, | |||||
| Activation = activation ?? keras.activations.Linear | |||||
| }); | |||||
| } | |||||
| /// <summary> | |||||
| /// 1D convolution layer (e.g. temporal convolution). | |||||
| /// This layer creates a convolution kernel that is convolved with the layer input over a single spatial(or temporal) dimension to produce a tensor of outputs.If use_bias is True, a bias vector is created and added to the outputs.Finally, if activation is not None, it is applied to the outputs as well. | |||||
| /// </summary> | |||||
| /// <param name="filters">Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution)</param> | |||||
| /// <param name="kernel_size">An integer specifying the width of the 1D convolution window.</param> | |||||
| /// <param name="strides">An integer specifying the stride of the convolution window . Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1.</param> | |||||
| /// <param name="padding">one of "valid" or "same" (case-insensitive). "valid" means no padding. "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input.</param> | |||||
| /// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. If you never set it, then it will be channels_last.</param> | |||||
| /// <param name="dilation_rate">An integer specifying the dilation rate to use for dilated convolution.Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.</param> | |||||
| /// <param name="groups">A positive integer specifying the number of groups in which the input is split along the channel axis. Each group is convolved separately with filters / groups filters. The output is the concatenation of all the groups results along the channel axis. Input channels and filters must both be divisible by groups.</param> | |||||
| /// <param name="activation">Activation function to use. If you don't specify anything, no activation is applied (see keras.activations).</param> | |||||
| /// <param name="use_bias">Boolean, whether the layer uses a bias vector.</param> | |||||
| /// <param name="kernel_initializer">Initializer for the kernel weights matrix (see keras.initializers).</param> | |||||
| /// <param name="bias_initializer">Initializer for the bias vector (see keras.initializers).</param> | |||||
| /// <returns>A tensor of rank 3 representing activation(conv1d(inputs, kernel) + bias).</returns> | |||||
| public Conv1D Conv1D(int filters, | |||||
| int? kernel_size = null, | |||||
| int? strides = null, | |||||
| string padding = "valid", | |||||
| string data_format = null, | |||||
| int? dilation_rate = null, | |||||
| int groups = 1, | |||||
| string activation = null, | |||||
| bool use_bias = true, | |||||
| string kernel_initializer = "glorot_uniform", | |||||
| string bias_initializer = "zeros") | |||||
| { | |||||
| // Special case: Conv1D will be implemented as Conv2D with H=1, so we need to add a 1-sized dimension to the kernel. | |||||
| // Lower-level logic handles the stride and dilation_rate, but the kernel_size needs to be set properly here. | |||||
| var kernel = (kernel_size == null) ? (1, 5) : (1, kernel_size.Value); | |||||
| return new Conv1D(new Conv1DArgs | |||||
| { | |||||
| Rank = 1, | |||||
| Filters = filters, | |||||
| KernelSize = kernel, | |||||
| Strides = strides == null ? 1 : strides, | |||||
| Padding = padding, | |||||
| DataFormat = data_format, | |||||
| DilationRate = dilation_rate == null ? 1 : dilation_rate, | |||||
| Groups = groups, | |||||
| UseBias = use_bias, | |||||
| Activation = GetActivationByName(activation), | |||||
| KernelInitializer = GetInitializerByName(kernel_initializer), | |||||
| BiasInitializer = GetInitializerByName(bias_initializer) | |||||
| }); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// 2D convolution layer (e.g. spatial convolution over images). | /// 2D convolution layer (e.g. spatial convolution over images). | ||||
| /// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. | /// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. | ||||
| @@ -105,7 +212,7 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| Rank = 2, | Rank = 2, | ||||
| Filters = filters, | Filters = filters, | ||||
| KernelSize = kernel_size, | |||||
| KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, | |||||
| Strides = strides == null ? (1, 1) : strides, | Strides = strides == null ? (1, 1) : strides, | ||||
| Padding = padding, | Padding = padding, | ||||
| DataFormat = data_format, | DataFormat = data_format, | ||||
| @@ -150,15 +257,12 @@ namespace Tensorflow.Keras.Layers | |||||
| string activation = null, | string activation = null, | ||||
| bool use_bias = true, | bool use_bias = true, | ||||
| string kernel_initializer = "glorot_uniform", | string kernel_initializer = "glorot_uniform", | ||||
| string bias_initializer = "zeros", | |||||
| string kernel_regularizer = null, | |||||
| string bias_regularizer = null, | |||||
| string activity_regularizer = null) | |||||
| string bias_initializer = "zeros") | |||||
| => new Conv2D(new Conv2DArgs | => new Conv2D(new Conv2DArgs | ||||
| { | { | ||||
| Rank = 2, | Rank = 2, | ||||
| Filters = filters, | Filters = filters, | ||||
| KernelSize = kernel_size, | |||||
| KernelSize = (kernel_size == null) ? (5,5) : kernel_size, | |||||
| Strides = strides == null ? (1, 1) : strides, | Strides = strides == null ? (1, 1) : strides, | ||||
| Padding = padding, | Padding = padding, | ||||
| DataFormat = data_format, | DataFormat = data_format, | ||||
| @@ -204,7 +308,7 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| Rank = 2, | Rank = 2, | ||||
| Filters = filters, | Filters = filters, | ||||
| KernelSize = kernel_size, | |||||
| KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, | |||||
| Strides = strides == null ? (1, 1) : strides, | Strides = strides == null ? (1, 1) : strides, | ||||
| Padding = output_padding, | Padding = output_padding, | ||||
| DataFormat = data_format, | DataFormat = data_format, | ||||
| @@ -0,0 +1,201 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Operations; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace TensorFlowNET.Keras.UnitTest | |||||
| { | |||||
| [TestClass] | |||||
| public class LayersConvolutionTest : EagerModeTestBase | |||||
| { | |||||
| [TestMethod] | |||||
| public void BasicConv1D() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv1D(filters, activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(3, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1] - 4, y.shape[1]); | |||||
| Assert.AreEqual(filters, y.shape[2]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv1D_ksize() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv1D(filters, kernel_size: 3, activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(3, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1] - 2, y.shape[1]); | |||||
| Assert.AreEqual(filters, y.shape[2]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv1D_ksize_same() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv1D(filters, kernel_size: 3, padding: "same", activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(3, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1], y.shape[1]); | |||||
| Assert.AreEqual(filters, y.shape[2]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv1D_ksize_strides() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv1D(filters, kernel_size: 3, strides: 2, activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(3, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1] - 5, y.shape[1]); | |||||
| Assert.AreEqual(filters, y.shape[2]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv1D_ksize_dilations() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv1D(filters, kernel_size: 3, dilation_rate: 2, activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(3, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1] - 4, y.shape[1]); | |||||
| Assert.AreEqual(filters, y.shape[2]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv1D_ksize_dilation_same() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv1D(filters, kernel_size: 3, dilation_rate: 2, padding: "same", activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(3, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1], y.shape[1]); | |||||
| Assert.AreEqual(filters, y.shape[2]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv2D() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv2D(filters, activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(1,8,8,4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(4, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1] - 4, y.shape[1]); | |||||
| Assert.AreEqual(x.shape[2] - 4, y.shape[2]); | |||||
| Assert.AreEqual(filters, y.shape[3]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv2D_ksize() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv2D(filters, kernel_size: 3, activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(1, 8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(4, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1] - 2, y.shape[1]); | |||||
| Assert.AreEqual(x.shape[2] - 2, y.shape[2]); | |||||
| Assert.AreEqual(filters, y.shape[3]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv2D_ksize_same() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv2D(filters, kernel_size: 3, padding: "same", activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(1, 8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(4, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1], y.shape[1]); | |||||
| Assert.AreEqual(x.shape[2], y.shape[2]); | |||||
| Assert.AreEqual(filters, y.shape[3]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv2D_ksize_strides() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv2D(filters, kernel_size: 3, strides: 2, activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(1, 8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(4, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1] - 5, y.shape[1]); | |||||
| Assert.AreEqual(x.shape[2] - 5, y.shape[2]); | |||||
| Assert.AreEqual(filters, y.shape[3]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv2D_ksize_dilation() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv2D(filters, kernel_size: 3, dilation_rate: 2, activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(1, 8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(4, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1] - 4, y.shape[1]); | |||||
| Assert.AreEqual(x.shape[2] - 4, y.shape[2]); | |||||
| Assert.AreEqual(filters, y.shape[3]); | |||||
| } | |||||
| [TestMethod] | |||||
| public void BasicConv2D_ksize_dilation_same() | |||||
| { | |||||
| var filters = 8; | |||||
| var conv = keras.layers.Conv2D(filters, kernel_size: 3, dilation_rate: 2, padding: "same", activation: "linear"); | |||||
| var x = np.arange(256.0f).reshape(1, 8, 8, 4); | |||||
| var y = conv.Apply(x); | |||||
| Assert.AreEqual(4, y.shape.ndim); | |||||
| Assert.AreEqual(x.shape[0], y.shape[0]); | |||||
| Assert.AreEqual(x.shape[1], y.shape[1]); | |||||
| Assert.AreEqual(x.shape[2], y.shape[2]); | |||||
| Assert.AreEqual(filters, y.shape[3]); | |||||
| } | |||||
| } | |||||
| } | |||||