diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 44703efb..c750fd16 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -27,6 +27,24 @@ namespace Tensorflow
 {
     public static class nn
     {
+        public static Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
+            string data_format = "NHWC", int[] dilations = null, string name = null)
+        {
+            if (dilations == null)
+                dilations = new[] { 1, 1, 1, 1 };
+
+            return gen_nn_ops.conv2d(new Conv2dParams
+            {
+                Input = input,
+                Filter = filter,
+                Strides = strides,
+                UseCudnnOnGpu = use_cudnn_on_gpu,
+                DataFormat = data_format,
+                Dilations = dilations,
+                Name = name
+            });
+        }
+
         /// <summary>
         /// Computes dropout.
         /// </summary>
@@ -90,7 +108,10 @@ namespace Tensorflow
                 is_training: is_training,
                 name: name);
 
-        public static IPoolFunction max_pool => new MaxPoolFunction();
+        public static IPoolFunction max_pool_fn => new MaxPoolFunction();
+
+        public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
+            => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
 
         public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
             => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);
diff --git a/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
index 649c1a33..bdb577e0 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers
             int[] strides,
             string padding = "valid",
             string data_format = null,
-            string name = null) : base(nn.max_pool, pool_size,
+            string name = null) : base(nn.max_pool_fn, pool_size,
                 strides,
                 padding: padding,
                 data_format: data_format,
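Reviewer note, not part of the patch: the two new static wrappers mirror the Python-side tf.nn.conv2d / tf.nn.max_pool calls and are exercised by the CNN example further down in this diff. A minimal usage sketch, assuming an NHWC input and an HWIO filter variable; the 28x28x1 input shape and 5x5x1x32 filter below are illustrative values, not taken from the patch:

    // Batch dimension left dynamic; the 4-tuple shape relies on the new
    // TensorShape implicit operator added later in this diff.
    var x = tf.placeholder(tf.float32, shape: (-1, 28, 28, 1), name: "X");
    var W = tf.get_variable("W", dtype: tf.float32,
        shape: new[] { 5, 5, 1, 32 },
        initializer: tf.truncated_normal_initializer(stddev: 0.01f));
    // dilations is optional and defaults to { 1, 1, 1, 1 }.
    var conv = tf.nn.conv2d(x, W, strides: new[] { 1, 1, 1, 1 }, padding: "SAME");
    // A 2x2 window with stride 2 halves the spatial dimensions: 28x28 -> 14x14.
    var pool = tf.nn.max_pool(conv, ksize: new[] { 1, 2, 2, 1 },
        strides: new[] { 1, 2, 2, 1 }, padding: "SAME", name: "pool1");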
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index fddbe40a..492e3ca7 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -118,6 +118,38 @@ namespace Tensorflow
             return _softmax(logits, gen_nn_ops.log_softmax, axis, name);
         }
 
+        /// <summary>
+        /// Performs the max pooling on the input.
+        /// </summary>
+        /// <param name="value">A 4-D `Tensor` of the format specified by `data_format`.</param>
+        /// <param name="ksize">
+        /// A list or tuple of 4 ints. The size of the window for each dimension
+        /// of the input tensor.
+        /// </param>
+        /// <param name="strides">
+        /// A list or tuple of 4 ints. The stride of the sliding window for
+        /// each dimension of the input tensor.
+        /// </param>
+        /// <param name="padding">A string, either `'VALID'` or `'SAME'`. The padding algorithm.</param>
+        /// <param name="data_format">A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.</param>
+        /// <param name="name">Optional name for the operation.</param>
+        /// <returns>The max pooled output tensor.</returns>
+        public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
+        {
+            return with(ops.name_scope(name, "MaxPool", value), scope =>
+            {
+                name = scope;
+                value = ops.convert_to_tensor(value, name: "input");
+                return gen_nn_ops.max_pool(
+                    value,
+                    ksize: ksize,
+                    strides: strides,
+                    padding: padding,
+                    data_format: data_format,
+                    name: name);
+            });
+        }
+
         public static Tensor _softmax(Tensor logits, Func<Tensor, string, Tensor> compute_op, int dim = -1, string name = null)
         {
             logits = ops.convert_to_tensor(logits);
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index d144bfb6..732e0a4e 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -24,6 +24,16 @@ namespace Tensorflow
         }
 
+        public TensorShape this[Slice slice]
+        {
+            get
+            {
+                return new TensorShape(Dimensions.Skip(slice.Start.Value)
+                    .Take(slice.Length.Value)
+                    .ToArray());
+            }
+        }
+
         /// <summary>
         /// Returns True iff `self` is fully defined in every dimension.
         /// </summary>
@@ -38,6 +48,9 @@ namespace Tensorflow
             throw new NotImplementedException("TensorShape is_compatible_with");
         }
 
+        public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
         public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
+        public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
+        public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
     }
 }
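Reviewer note, not part of the patch: a short sketch of what the new TensorShape members allow. The slice semantics are inferred from the Skip/Take implementation above (Start is the first dimension kept, Length the number kept); whether NumSharp's Slice(1, 4) carries start/stop semantics is an assumption here, so treat the commented results as illustrative:

    // Implicit conversions: shapes can now be written as arrays or tuples.
    TensorShape from_array = new[] { 5, 5, 1, 32 };   // new int[] operator
    TensorShape from_tuple = (-1, 28, 28, 16);        // new 4-tuple operator

    // Slice indexer: keeps Dimensions[Start .. Start + Length).
    // flatten_layer in the CNN example below uses exactly this to collapse
    // the three trailing dimensions of a pooled NHWC tensor.
    var tail = from_tuple[new Slice(1, 4)];           // (28, 28, 16), if Length == 3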
diff --git a/test/TensorFlowNET.Examples/IExample.cs b/test/TensorFlowNET.Examples/IExample.cs
index 8b07d6ed..d9826bae 100644
--- a/test/TensorFlowNET.Examples/IExample.cs
+++ b/test/TensorFlowNET.Examples/IExample.cs
@@ -1,4 +1,20 @@
-using System;
+/*****************************************************************************
+   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+******************************************************************************/
+
+using System;
 using System.Collections.Generic;
 using System.Text;
 using Tensorflow;
diff --git a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
index 0abfaf2a..d487dbee 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
@@ -14,6 +14,7 @@ limitations under the License.
 ******************************************************************************/
 
+using NumSharp;
 using System;
 using System.Collections.Generic;
 using System.Text;
@@ -65,7 +66,7 @@ namespace TensorFlowNET.Examples.ImageProcess
 
         Tensor x, y;
-        Tensor loss, accuracy;
+        Tensor loss, accuracy, cls_prediction;
         Operation optimizer;
         int display_freq = 100;
@@ -90,47 +91,148 @@ namespace TensorFlowNET.Examples.ImageProcess
         {
             var graph = new Graph().as_default();
 
-            // Placeholders for inputs (x) and outputs(y)
-            x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
-            y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
+            with(tf.name_scope("Input"), delegate
+            {
+                // Placeholders for inputs (x) and outputs (y)
+                x = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "X");
+                y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
+            });
 
-            // Create a fully-connected layer with h1 nodes as hidden layer
-            var fc1 = fc_layer(x, h1, "FC1", use_relu: true);
-            // Create a fully-connected layer with n_classes nodes as output layer
+            var conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name: "conv1");
+            var pool1 = max_pool(conv1, ksize: 2, stride: 2, name: "pool1");
+            var conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name: "conv2");
+            var pool2 = max_pool(conv2, ksize: 2, stride: 2, name: "pool2");
+            var layer_flat = flatten_layer(pool2);
+            var fc1 = fc_layer(layer_flat, h1, "FC1", use_relu: true);
             var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
 
-            // Define the loss function, optimizer, and accuracy
-            var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
-            loss = tf.reduce_mean(logits, name: "loss");
-            optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
-            var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
-            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
-            // Network predictions
-            var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
+            with(tf.variable_scope("Train"), delegate
+            {
+                with(tf.variable_scope("Loss"), delegate
+                {
+                    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits), name: "loss");
+                });
+
+                with(tf.variable_scope("Optimizer"), delegate
+                {
+                    optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
+                });
+
+                with(tf.variable_scope("Accuracy"), delegate
+                {
+                    var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
+                    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
+                });
+
+                with(tf.variable_scope("Prediction"), delegate
+                {
+                    cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
+                });
+            });
 
             return graph;
         }
 
-        private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
+        /// <summary>
+        /// Create a 2D convolution layer
+        /// </summary>
+        /// <param name="x">input from previous layer</param>
+        /// <param name="filter_size">size of each filter</param>
+        /// <param name="num_filters">number of filters (or output feature maps)</param>
+        /// <param name="stride">filter stride</param>
+        /// <param name="name">layer name</param>
+        /// <returns>The output array</returns>
+        private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride, string name)
+        {
+            return with(tf.variable_scope(name), delegate {
+
+                var num_in_channel = x.shape[x.NDims - 1];
+                var shape = new[] { filter_size, filter_size, num_in_channel, num_filters };
+                var W = weight_variable("W", shape);
+                // tf.summary.histogram("weight", W);
+                var b = bias_variable("b", new[] { num_filters });
+                // tf.summary.histogram("bias", b);
+                var layer = tf.nn.conv2d(x, W,
+                    strides: new[] { 1, stride, stride, 1 },
+                    padding: "SAME");
+                layer += b;
+                return tf.nn.relu(layer);
+            });
+
+        }
+
+        /// <summary>
+        /// Create a max pooling layer
+        /// </summary>
+        /// <param name="x">input to max-pooling layer</param>
+        /// <param name="ksize">size of the max-pooling filter</param>
+        /// <param name="stride">stride of the max-pooling filter</param>
+        /// <param name="name">layer name</param>
+        /// <returns>The output array</returns>
+        private Tensor max_pool(Tensor x, int ksize, int stride, string name)
         {
-            var in_dim = x.shape[1];
+            return tf.nn.max_pool(x,
+                ksize: new[] { 1, ksize, ksize, 1 },
+                strides: new[] { 1, stride, stride, 1 },
+                padding: "SAME",
+                name: name);
+        }
 
+        /// <summary>
+        /// Flattens the output of the convolutional layer to be fed into a fully-connected layer
+        /// </summary>
+        /// <param name="layer">input array</param>
+        /// <returns>flattened array</returns>
+        private Tensor flatten_layer(Tensor layer)
+        {
+            return with(tf.variable_scope("Flatten_layer"), delegate
+            {
+                var layer_shape = layer.TensorShape;
+                var num_features = layer_shape[new Slice(1, 4)].Size;
+                var layer_flat = tf.reshape(layer, new[] { -1, num_features });
+
+                return layer_flat;
+            });
+        }
+
+        private Tensor weight_variable(string name, int[] shape)
+        {
             var initer = tf.truncated_normal_initializer(stddev: 0.01f);
-            var W = tf.get_variable("W_" + name,
-                dtype: tf.float32,
-                shape: (in_dim, num_units),
-                initializer: initer);
+            return tf.get_variable(name,
+                dtype: tf.float32,
+                shape: shape,
+                initializer: initer);
+        }
 
-            var initial = tf.constant(0f, num_units);
-            var b = tf.get_variable("b_" + name,
-                dtype: tf.float32,
-                initializer: initial);
+        /// <summary>
+        /// Create a bias variable with appropriate initialization
+        /// </summary>
+        /// <param name="name">variable name</param>
+        /// <param name="shape">bias shape</param>
+        /// <returns>the bias variable</returns>
+        private Tensor bias_variable(string name, int[] shape)
+        {
+            var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
+            return tf.get_variable(name,
+                dtype: tf.float32,
+                initializer: initial);
+        }
+
+        private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
+        {
+            return with(tf.variable_scope(name), delegate
+            {
+                var in_dim = x.shape[1];
 
-            var layer = tf.matmul(x, W) + b;
-            if (use_relu)
-                layer = tf.nn.relu(layer);
+                var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units });
+                var b = bias_variable("b_" + name, new[] { num_units });
 
-            return layer;
+                var layer = tf.matmul(x, W) + b;
+                if (use_relu)
+                    layer = tf.nn.relu(layer);
+
+                return layer;
+            });
         }
 
         public Graph ImportGraph() => throw new NotImplementedException();
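Reviewer note, not part of the patch: a shape walk-through of the rebuilt graph, useful when checking flatten_layer. The filter counts and conv strides are fields defined outside these hunks, so they stay symbolic below; spatial sizes assume the 28x28 MNIST input, unit conv strides, and the SAME padding used above:

    // x:          [-1, 28, 28, n_channels]
    // conv1:      [-1, 28, 28, num_filters1]  (SAME padding, stride 1)
    // pool1:      [-1, 14, 14, num_filters1]  (2x2 window, stride 2)
    // conv2:      [-1, 14, 14, num_filters2]
    // pool2:      [-1,  7,  7, num_filters2]
    // layer_flat: [-1, 7 * 7 * num_filters2]

    // Checking flatten_layer's arithmetic with an assumed num_filters2 = 32:
    TensorShape pool2_shape = (-1, 7, 7, 32);
    var num_features = pool2_shape[new Slice(1, 4)].Size;   // 7 * 7 * 32 = 1568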
diff --git a/test/TensorFlowNET.Examples/Utility/Compress.cs b/test/TensorFlowNET.Examples/Utility/Compress.cs
index bc38434b..95eb0ddf 100644
--- a/test/TensorFlowNET.Examples/Utility/Compress.cs
+++ b/test/TensorFlowNET.Examples/Utility/Compress.cs
@@ -1,4 +1,20 @@
-using ICSharpCode.SharpZipLib.Core;
+/*****************************************************************************
+   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+******************************************************************************/
+
+using ICSharpCode.SharpZipLib.Core;
 using ICSharpCode.SharpZipLib.GZip;
 using ICSharpCode.SharpZipLib.Tar;
 using System;
diff --git a/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs b/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs
index d24ea87d..0825a702 100644
--- a/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs
+++ b/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs
@@ -1,4 +1,20 @@
-using NumSharp;
+/*****************************************************************************
+   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+******************************************************************************/
+
+using NumSharp;
 using System;
 using System.Collections.Generic;
 using System.Text;
diff --git a/test/TensorFlowNET.Examples/Utility/Web.cs b/test/TensorFlowNET.Examples/Utility/Web.cs
index e2155e93..95e2c762 100644
--- a/test/TensorFlowNET.Examples/Utility/Web.cs
+++ b/test/TensorFlowNET.Examples/Utility/Web.cs
@@ -1,4 +1,20 @@
-using System;
+/*****************************************************************************
+   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+******************************************************************************/
+
+using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;