diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 44703efb..c750fd16 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -27,6 +27,24 @@ namespace Tensorflow
{
public static class nn
{
+ public static Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
+ string data_format = "NHWC", int[] dilations = null, string name = null)
+ {
+ if (dilations == null)
+ dilations = new[] { 1, 1, 1, 1 };
+
+ return gen_nn_ops.conv2d(new Conv2dParams
+ {
+ Input = input,
+ Filter = filter,
+ Strides = strides,
+ UseCudnnOnGpu = use_cudnn_on_gpu,
+ DataFormat = data_format,
+ Dilations = dilations,
+ Name = name
+ });
+ }
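+ // Illustrative usage (a sketch, not part of this patch): given a 4-D NHWC
+ // input and a filter variable, e.g. one created via tf.get_variable,
+ // dilations defaults to { 1, 1, 1, 1 } when omitted:
+ //   var y = nn.conv2d(input, filter, strides: new[] { 1, 1, 1, 1 }, padding: "SAME");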
+
/// <summary>
/// Computes dropout.
/// </summary>
@@ -90,7 +108,10 @@ namespace Tensorflow
is_training: is_training,
name: name);
- public static IPoolFunction max_pool => new MaxPoolFunction();
+ public static IPoolFunction max_pool_fn => new MaxPoolFunction();
+
+ public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
+ => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
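+
+ // Illustrative usage (sketch only): a 2x2 window with stride 2 halves the
+ // spatial dimensions of an NHWC input:
+ //   var pooled = nn.max_pool(value, ksize: new[] { 1, 2, 2, 1 },
+ //       strides: new[] { 1, 2, 2, 1 }, padding: "SAME");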
public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
=> gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);
diff --git a/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
index 649c1a33..bdb577e0 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers
int[] strides,
string padding = "valid",
string data_format = null,
- string name = null) : base(nn.max_pool, pool_size,
+ string name = null) : base(nn.max_pool_fn, pool_size,
strides,
padding: padding,
data_format: data_format,
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index fddbe40a..492e3ca7 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -118,6 +118,38 @@ namespace Tensorflow
return _softmax(logits, gen_nn_ops.log_softmax, axis, name);
}
+ /// <summary>
+ /// Performs the max pooling on the input.
+ /// </summary>
+ /// <param name="value">A 4-D `Tensor` of the format specified by `data_format`.</param>
+ /// <param name="ksize">
+ /// A list or tuple of 4 ints. The size of the window for each dimension
+ /// of the input tensor.
+ /// </param>
+ /// <param name="strides">
+ /// A list or tuple of 4 ints. The stride of the sliding window for
+ /// each dimension of the input tensor.
+ /// </param>
+ /// <param name="padding">A string, either `'VALID'` or `'SAME'`. The padding algorithm.</param>
+ /// <param name="data_format">A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.</param>
+ /// <param name="name">Optional name for the operation.</param>
+ /// <returns></returns>
+ public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
+ {
+ return with(ops.name_scope(name, "MaxPool", value), scope =>
+ {
+ name = scope;
+ value = ops.convert_to_tensor(value, name: "input");
+ return gen_nn_ops.max_pool(
+ value,
+ ksize: ksize,
+ strides: strides,
+ padding: padding,
+ data_format: data_format,
+ name: name);
+ });
+ }
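+
+ // For reference (not part of this patch): with "SAME" padding the spatial
+ // output size is ceil(in / stride); with "VALID" it is
+ // ceil((in - ksize + 1) / stride).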
+
public static Tensor _softmax(Tensor logits, Func<Tensor, string, Tensor> compute_op, int dim = -1, string name = null)
{
logits = ops.convert_to_tensor(logits);
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index d144bfb6..732e0a4e 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -24,6 +24,16 @@ namespace Tensorflow
}
+ public TensorShape this[Slice slice]
+ {
+ get
+ {
+ return new TensorShape(Dimensions.Skip(slice.Start.Value)
+ .Take(slice.Length.Value)
+ .ToArray());
+ }
+ }
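+
+ // Illustrative (a sketch, assuming NumSharp's Slice(start, stop) semantics):
+ // for a shape (-1, 7, 7, 32), this[new Slice(1, 4)] selects dimensions
+ // 1 through 3, i.e. the (7, 7, 32) sub-shape.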
+
/// <summary>
/// Returns True iff `self` is fully defined in every dimension.
/// </summary>
@@ -38,6 +48,9 @@ namespace Tensorflow
throw new NotImplementedException("TensorShape is_compatible_with");
}
+ public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
+ public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
+ public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
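+
+ // Illustrative usage (sketch only): the conversions above let call sites
+ // write shapes directly, e.g.
+ //   TensorShape s1 = new[] { -1, 784 };
+ //   TensorShape s2 = (28, 28, 1);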
}
}
diff --git a/test/TensorFlowNET.Examples/IExample.cs b/test/TensorFlowNET.Examples/IExample.cs
index 8b07d6ed..d9826bae 100644
--- a/test/TensorFlowNET.Examples/IExample.cs
+++ b/test/TensorFlowNET.Examples/IExample.cs
@@ -1,4 +1,20 @@
-using System;
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
diff --git a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
index 0abfaf2a..d487dbee 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
@@ -65,7 +66,7 @@ namespace TensorFlowNET.Examples.ImageProcess
Tensor x, y;
- Tensor loss, accuracy;
+ Tensor loss, accuracy, cls_prediction;
Operation optimizer;
int display_freq = 100;
@@ -90,47 +91,148 @@ namespace TensorFlowNET.Examples.ImageProcess
{
var graph = new Graph().as_default();
- // Placeholders for inputs (x) and outputs(y)
- x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
- y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
+ with(tf.name_scope("Input"), delegate
+ {
+ // Placeholders for inputs (x) and outputs(y)
+ x = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "X");
+ y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
+ });
- // Create a fully-connected layer with h1 nodes as hidden layer
- var fc1 = fc_layer(x, h1, "FC1", use_relu: true);
- // Create a fully-connected layer with n_classes nodes as output layer
+ var conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name: "conv1");
+ var pool1 = max_pool(conv1, ksize: 2, stride: 2, name: "pool1");
+ var conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name: "conv2");
+ var pool2 = max_pool(conv2, ksize: 2, stride: 2, name: "pool2");
+ var layer_flat = flatten_layer(pool2);
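+ // Shape note (illustrative, assuming 28x28x1 MNIST input, stride-1
+ // convolutions and "SAME" padding): each 2x2/stride-2 pool halves the
+ // spatial size, so pool2 is 7x7 and layer_flat has 7 * 7 * num_filters2
+ // features.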
+ var fc1 = fc_layer(layer_flat, h1, "FC1", use_relu: true);
var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
- // Define the loss function, optimizer, and accuracy
- var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
- loss = tf.reduce_mean(logits, name: "loss");
- optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
- var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
- accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
- // Network predictions
- var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
+ with(tf.variable_scope("Train"), delegate
+ {
+ with(tf.variable_scope("Loss"), delegate
+ {
+ loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits), name: "loss");
+ });
+
+ with(tf.variable_scope("Optimizer"), delegate
+ {
+ optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
+ });
+
+ with(tf.variable_scope("Accuracy"), delegate
+ {
+ var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
+ });
+
+ with(tf.variable_scope("Prediction"), delegate
+ {
+ cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
+ });
+ });
return graph;
}
- private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
+ /// <summary>
+ /// Create a 2D convolution layer
+ /// </summary>
+ /// <param name="x">input from previous layer</param>
+ /// <param name="filter_size">size of each filter</param>
+ /// <param name="num_filters">number of filters (or output feature maps)</param>
+ /// <param name="stride">filter stride</param>
+ /// <param name="name">layer name</param>
+ /// <returns>The output array</returns>
+ private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride, string name)
+ {
+ return with(tf.variable_scope(name), delegate {
+ var num_in_channel = x.shape[x.NDims - 1];
+ var shape = new[] { filter_size, filter_size, num_in_channel, num_filters };
+ var W = weight_variable("W", shape);
+ // tf.summary.histogram("weight", W);
+ var b = bias_variable("b", new[] { num_filters });
+ // tf.summary.histogram("bias", b);
+ var layer = tf.nn.conv2d(x, W,
+ strides: new[] { 1, stride, stride, 1 },
+ padding: "SAME");
+ layer += b;
+ return tf.nn.relu(layer);
+ });
+ }
+
+ /// <summary>
+ /// Create a max pooling layer
+ /// </summary>
+ /// <param name="x">input to max-pooling layer</param>
+ /// <param name="ksize">size of the max-pooling filter</param>
+ /// <param name="stride">stride of the max-pooling filter</param>
+ /// <param name="name">layer name</param>
+ /// <returns>The output array</returns>
+ private Tensor max_pool(Tensor x, int ksize, int stride, string name)
{
- var in_dim = x.shape[1];
+ return tf.nn.max_pool(x,
+ ksize: new[] { 1, ksize, ksize, 1 },
+ strides: new[] { 1, stride, stride, 1 },
+ padding: "SAME",
+ name: name);
+ }
+
+ /// <summary>
+ /// Flattens the output of the convolutional layer to be fed into fully-connected layer
+ /// </summary>
+ /// <param name="layer">input array</param>
+ /// <returns>flattened array</returns>
+ private Tensor flatten_layer(Tensor layer)
+ {
+ return with(tf.variable_scope("Flatten_layer"), delegate
+ {
+ var layer_shape = layer.TensorShape;
+ var num_features = layer_shape[new Slice(1, 4)].Size;
+ var layer_flat = tf.reshape(layer, new[] { -1, num_features });
+
+ return layer_flat;
+ });
+ }
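+
+ // Worked example (hypothetical numbers): for a pool output of shape
+ // (-1, 7, 7, 32), the Slice(1, 4) sub-shape is (7, 7, 32) and
+ // num_features = 7 * 7 * 32 = 1568.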
+
+ private Tensor weight_variable(string name, int[] shape)
+ {
var initer = tf.truncated_normal_initializer(stddev: 0.01f);
- var W = tf.get_variable("W_" + name,
- dtype: tf.float32,
- shape: (in_dim, num_units),
- initializer: initer);
+ return tf.get_variable(name,
+ dtype: tf.float32,
+ shape: shape,
+ initializer: initer);
+ }
- var initial = tf.constant(0f, num_units);
- var b = tf.get_variable("b_" + name,
- dtype: tf.float32,
- initializer: initial);
+ /// <summary>
+ /// Create a bias variable with appropriate initialization
+ /// </summary>
+ /// <param name="name"></param>
+ /// <param name="shape"></param>
+ /// <returns></returns>
+ private Tensor bias_variable(string name, int[] shape)
+ {
+ var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
+ return tf.get_variable(name,
+ dtype: tf.float32,
+ initializer: initial);
+ }
+
+ private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
+ {
+ return with(tf.variable_scope(name), delegate
+ {
+ var in_dim = x.shape[1];
- var layer = tf.matmul(x, W) + b;
- if (use_relu)
- layer = tf.nn.relu(layer);
+ var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units });
+ var b = bias_variable("b_" + name, new[] { num_units });
- return layer;
+ var layer = tf.matmul(x, W) + b;
+ if (use_relu)
+ layer = tf.nn.relu(layer);
+
+ return layer;
+ });
}
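+
+ // Shape note (for reference): x is (batch, in_dim) and W is (in_dim,
+ // num_units), so tf.matmul(x, W) + b yields (batch, num_units).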
public Graph ImportGraph() => throw new NotImplementedException();
diff --git a/test/TensorFlowNET.Examples/Utility/Compress.cs b/test/TensorFlowNET.Examples/Utility/Compress.cs
index bc38434b..95eb0ddf 100644
--- a/test/TensorFlowNET.Examples/Utility/Compress.cs
+++ b/test/TensorFlowNET.Examples/Utility/Compress.cs
@@ -1,4 +1,20 @@
-using ICSharpCode.SharpZipLib.Core;
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using System;
diff --git a/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs b/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs
index d24ea87d..0825a702 100644
--- a/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs
+++ b/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs
@@ -1,4 +1,20 @@
-using NumSharp;
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
diff --git a/test/TensorFlowNET.Examples/Utility/Web.cs b/test/TensorFlowNET.Examples/Utility/Web.cs
index e2155e93..95e2c762 100644
--- a/test/TensorFlowNET.Examples/Utility/Web.cs
+++ b/test/TensorFlowNET.Examples/Utility/Web.cs
@@ -1,4 +1,20 @@
-using System;
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;