diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index c750fd16..dae50e27 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -30,19 +30,21 @@ namespace Tensorflow public static Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true, string data_format= "NHWC", int[] dilations= null, string name = null) { - if (dilations == null) - dilations = new[] { 1, 1, 1, 1 }; - - return gen_nn_ops.conv2d(new Conv2dParams + var parameters = new Conv2dParams { Input = input, Filter = filter, Strides = strides, + Padding = padding, UseCudnnOnGpu = use_cudnn_on_gpu, DataFormat = data_format, - Dilations = dilations, Name = name - }); + }; + + if (dilations != null) + parameters.Dilations = dilations; + + return gen_nn_ops.conv2d(parameters); } /// diff --git a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs index d487dbee..a5a70f92 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs @@ -158,7 +158,6 @@ namespace TensorFlowNET.Examples.ImageProcess layer += b; return tf.nn.relu(layer); }); - } /// @@ -195,7 +194,7 @@ namespace TensorFlowNET.Examples.ImageProcess }); } - private Tensor weight_variable(string name, int[] shape) + private RefVariable weight_variable(string name, int[] shape) { var initer = tf.truncated_normal_initializer(stddev: 0.01f); return tf.get_variable(name, @@ -210,7 +209,7 @@ namespace TensorFlowNET.Examples.ImageProcess /// /// /// - private Tensor bias_variable(string name, int[] shape) + private RefVariable bias_variable(string name, int[] shape) { var initial = tf.constant(0f, shape: shape, dtype: tf.float32); return tf.get_variable(name,