@@ -27,6 +27,24 @@ namespace Tensorflow
{
    public static class nn
    {
        public static Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
            string data_format = "NHWC", int[] dilations = null, string name = null)
        {
            if (dilations == null)
                dilations = new[] { 1, 1, 1, 1 };

            return gen_nn_ops.conv2d(new Conv2dParams
            {
                Input = input,
                Filter = filter,
                Strides = strides,
                UseCudnnOnGpu = use_cudnn_on_gpu,
                DataFormat = data_format,
                Dilations = dilations,
                Name = name
            });
        }
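
        // Usage sketch (hypothetical shapes and names): convolve a 4-D NHWC input with a
        // 3x3 kernel at stride 1; SAME padding keeps the 28x28 spatial size and the
        // filter's last dimension gives 16 output channels.
        //
        //   var x = tf.placeholder(tf.float32, shape: (-1, 28, 28, 1));
        //   var W = tf.get_variable("kernel", dtype: tf.float32, shape: new[] { 3, 3, 1, 16 },
        //       initializer: tf.truncated_normal_initializer(stddev: 0.01f));
        //   var y = nn.conv2d(x, W, strides: new[] { 1, 1, 1, 1 }, padding: "SAME");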

        /// <summary>
        /// Computes dropout.
        /// </summary>

@@ -90,7 +108,10 @@ namespace Tensorflow
            is_training: is_training,
            name: name);
        public static IPoolFunction max_pool_fn => new MaxPoolFunction();

        public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
            => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);

        public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
            => gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);
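
        // Usage sketch (hypothetical input): top_k returns a pair of tensors holding the
        // k largest values along the last dimension and their indices.
        //
        //   var logits = tf.placeholder(tf.float32, shape: (-1, 10));
        //   var result = nn.top_k(logits, k: 3);  // result[0]: values, result[1]: indices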

@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers
            int[] strides,
            string padding = "valid",
            string data_format = null,
            string name = null) : base(nn.max_pool_fn, pool_size,
                strides,
                padding: padding,
                data_format: data_format,

@@ -118,6 +118,38 @@ namespace Tensorflow
            return _softmax(logits, gen_nn_ops.log_softmax, axis, name);
        }

        /// <summary>
        /// Performs the max pooling on the input.
        /// </summary>
        /// <param name="value">A 4-D `Tensor` of the format specified by `data_format`.</param>
        /// <param name="ksize">
        /// A list or tuple of 4 ints. The size of the window for each dimension
        /// of the input tensor.
        /// </param>
        /// <param name="strides">
        /// A list or tuple of 4 ints. The stride of the sliding window for
        /// each dimension of the input tensor.
        /// </param>
        /// <param name="padding">A string, either `'VALID'` or `'SAME'`. The padding algorithm.</param>
        /// <param name="data_format">A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.</param>
        /// <param name="name">Optional name for the operation.</param>
        /// <returns>The max pooled output `Tensor`, in the format specified by `data_format`.</returns>
        public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
        {
            return with(ops.name_scope(name, "MaxPool", value), scope =>
            {
                name = scope;
                value = ops.convert_to_tensor(value, name: "input");
                return gen_nn_ops.max_pool(
                    value,
                    ksize: ksize,
                    strides: strides,
                    padding: padding,
                    data_format: data_format,
                    name: name);
            });
        }
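
        // Shape sketch (standard pooling arithmetic, sizes hypothetical): 2x2 pooling at
        // stride 2 halves each spatial dimension of a (batch, 28, 28, c) NHWC input.
        //
        //   var pooled = max_pool(images, ksize: new[] { 1, 2, 2, 1 },
        //       strides: new[] { 1, 2, 2, 1 }, padding: "SAME");
        //   // SAME: ceil(28 / 2) = 14; VALID: floor((28 - 2) / 2) + 1 = 14 here as well,
        //   // so pooled has shape (batch, 14, 14, c).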

        public static Tensor _softmax(Tensor logits, Func<Tensor, string, Tensor> compute_op, int dim = -1, string name = null)
        {
            logits = ops.convert_to_tensor(logits);

@@ -24,6 +24,16 @@ namespace Tensorflow
        }

        public TensorShape this[Slice slice]
        {
            get
            {
                return new TensorShape(Dimensions.Skip(slice.Start.Value)
                    .Take(slice.Length.Value)
                    .ToArray());
            }
        }
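
        // Usage sketch (hypothetical shape, assuming NumSharp's Slice(start, stop)
        // semantics, so Length = stop - start): keep the three non-batch dimensions
        // of a 4-D shape.
        //
        //   var shape = new TensorShape(-1, 14, 14, 32);
        //   var spatial = shape[new Slice(1, 4)];  // (14, 14, 32)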

        /// <summary>
        /// Returns True iff `self` is fully defined in every dimension.
        /// </summary>

@@ -38,6 +48,9 @@ namespace Tensorflow
            throw new NotImplementedException("TensorShape is_compatible_with");
        }

        public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
        public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
        public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
        public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
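
        // Usage sketch: the tuple conversions let shapes be written inline, with -1
        // standing for an unknown dimension, as the example placeholders below do.
        //
        //   TensorShape two = (28, 28);           // rank-2 shape
        //   TensorShape four = (-1, 28, 28, 1);   // rank-4 shape, unknown batch size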
    }
}

@@ -1,4 +1,20 @@
/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

@@ -14,6 +14,7 @@
   limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;

@@ -65,7 +66,7 @@ namespace TensorFlowNET.Examples.ImageProcess
        Tensor x, y;
        Tensor loss, accuracy, cls_prediction;
        Operation optimizer;
        int display_freq = 100;

@@ -90,47 +91,148 @@ namespace TensorFlowNET.Examples.ImageProcess
        {
            var graph = new Graph().as_default();

            with(tf.name_scope("Input"), delegate
            {
                // Placeholders for inputs (x) and outputs (y)
                x = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "X");
                y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
            });
            var conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name: "conv1");
            var pool1 = max_pool(conv1, ksize: 2, stride: 2, name: "pool1");
            var conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name: "conv2");
            var pool2 = max_pool(conv2, ksize: 2, stride: 2, name: "pool2");
            var layer_flat = flatten_layer(pool2);
            // Create a fully-connected layer with h1 nodes as hidden layer
            var fc1 = fc_layer(layer_flat, h1, "FC1", use_relu: true);
            // Create a fully-connected layer with n_classes nodes as output layer
            var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);

            // Define the loss function, optimizer, accuracy, and network predictions
            with(tf.variable_scope("Train"), delegate
            {
                with(tf.variable_scope("Loss"), delegate
                {
                    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits), name: "loss");
                });

                with(tf.variable_scope("Optimizer"), delegate
                {
                    optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
                });

                with(tf.variable_scope("Accuracy"), delegate
                {
                    var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
                    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
                });

                with(tf.variable_scope("Prediction"), delegate
                {
                    cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
                });
            });

            return graph;
        }

        /// <summary>
        /// Create a 2D convolution layer
        /// </summary>
        /// <param name="x">input from previous layer</param>
        /// <param name="filter_size">size of each filter</param>
        /// <param name="num_filters">number of filters (or output feature maps)</param>
        /// <param name="stride">filter stride</param>
        /// <param name="name">layer name</param>
        /// <returns>The output tensor</returns>
        private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride, string name)
        {
            return with(tf.variable_scope(name), delegate
            {
                var num_in_channel = x.shape[x.NDims - 1];
                var shape = new[] { filter_size, filter_size, num_in_channel, num_filters };
                var W = weight_variable("W", shape);
                // tf.summary.histogram("weight", W);
                var b = bias_variable("b", new[] { num_filters });
                // tf.summary.histogram("bias", b);
                var layer = tf.nn.conv2d(x, W,
                    strides: new[] { 1, stride, stride, 1 },
                    padding: "SAME");
                layer += b;
                return tf.nn.relu(layer);
            });
        }
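
        // Shape sketch (standard convolution arithmetic, sizes hypothetical): with SAME
        // padding the output spatial size is ceil(in / stride), so a (batch, 28, 28, 1)
        // input through conv_layer(x, filter_size: 5, num_filters: 16, stride: 1, "conv1")
        // comes out as (batch, 28, 28, 16).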

        /// <summary>
        /// Create a max pooling layer
        /// </summary>
        /// <param name="x">input to max-pooling layer</param>
        /// <param name="ksize">size of the max-pooling filter</param>
        /// <param name="stride">stride of the max-pooling filter</param>
        /// <param name="name">layer name</param>
        /// <returns>The output tensor</returns>
        private Tensor max_pool(Tensor x, int ksize, int stride, string name)
        {
            return tf.nn.max_pool(x,
                ksize: new[] { 1, ksize, ksize, 1 },
                strides: new[] { 1, stride, stride, 1 },
                padding: "SAME",
                name: name);
        }

        /// <summary>
        /// Flattens the output of the convolutional layer to be fed into a fully-connected layer
        /// </summary>
        /// <param name="layer">input tensor</param>
        /// <returns>flattened tensor</returns>
        private Tensor flatten_layer(Tensor layer)
        {
            return with(tf.variable_scope("Flatten_layer"), delegate
            {
                var layer_shape = layer.TensorShape;
                var num_features = layer_shape[new Slice(1, 4)].Size;
                var layer_flat = tf.reshape(layer, new[] { -1, num_features });
                return layer_flat;
            });
        }
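
        // Shape sketch (hypothetical input): for pool2 of shape (batch, 7, 7, 32),
        // layer_shape[new Slice(1, 4)].Size is 7 * 7 * 32 = 1568, so the reshape
        // produces a (batch, 1568) tensor for the fully-connected layers.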

        /// <summary>
        /// Create a weight variable with appropriate initialization
        /// </summary>
        /// <param name="name">variable name</param>
        /// <param name="shape">weight shape</param>
        /// <returns>initialized weight variable</returns>
        private Tensor weight_variable(string name, int[] shape)
        {
            var initer = tf.truncated_normal_initializer(stddev: 0.01f);
            return tf.get_variable(name,
                dtype: tf.float32,
                shape: shape,
                initializer: initer);
        }

        /// <summary>
        /// Create a bias variable with appropriate initialization
        /// </summary>
        /// <param name="name">variable name</param>
        /// <param name="shape">bias shape</param>
        /// <returns>initialized bias variable</returns>
        private Tensor bias_variable(string name, int[] shape)
        {
            var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
            return tf.get_variable(name,
                dtype: tf.float32,
                initializer: initial);
        }

        private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
        {
            return with(tf.variable_scope(name), delegate
            {
                var in_dim = x.shape[1];

                var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units });
                var b = bias_variable("b_" + name, new[] { num_units });

                var layer = tf.matmul(x, W) + b;
                if (use_relu)
                    layer = tf.nn.relu(layer);

                return layer;
            });
        }

        public Graph ImportGraph() => throw new NotImplementedException();

@@ -1,4 +1,20 @@
/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using System;

@@ -1,4 +1,20 @@
/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;

@@ -1,4 +1,20 @@
/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;