Browse Source

MNIST CNN not finished yet.

tags/v0.10
Oceania2018 6 years ago
parent
commit
bc4fbf10f2
9 changed files with 267 additions and 35 deletions
  1. +22
    -1
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs
  3. +32
    -0
      src/TensorFlowNET.Core/Operations/nn_ops.cs
  4. +13
    -0
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  5. +17
    -1
      test/TensorFlowNET.Examples/IExample.cs
  6. +131
    -29
      test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs
  7. +17
    -1
      test/TensorFlowNET.Examples/Utility/Compress.cs
  8. +17
    -1
      test/TensorFlowNET.Examples/Utility/DataSetMnist.cs
  9. +17
    -1
      test/TensorFlowNET.Examples/Utility/Web.cs

+ 22
- 1
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -27,6 +27,24 @@ namespace Tensorflow
{
public static class nn
{
public static Tensor conv2d(Tensor input, RefVariable filter, int[] strides, string padding, bool use_cudnn_on_gpu = true,
string data_format= "NHWC", int[] dilations= null, string name = null)
{
if (dilations == null)
dilations = new[] { 1, 1, 1, 1 };

return gen_nn_ops.conv2d(new Conv2dParams
{
Input = input,
Filter = filter,
Strides = strides,
UseCudnnOnGpu = use_cudnn_on_gpu,
DataFormat = data_format,
Dilations = dilations,
Name = name
});
}

/// <summary>
/// Computes dropout.
/// </summary>
@@ -90,7 +108,10 @@ namespace Tensorflow
is_training: is_training,
name: name);

public static IPoolFunction max_pool => new MaxPoolFunction();
public static IPoolFunction max_pool_fn => new MaxPoolFunction();

public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);

public static Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null)
=> gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/MaxPooling2D.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers
int[] strides,
string padding = "valid",
string data_format = null,
string name = null) : base(nn.max_pool, pool_size,
string name = null) : base(nn.max_pool_fn, pool_size,
strides,
padding: padding,
data_format: data_format,


+ 32
- 0
src/TensorFlowNET.Core/Operations/nn_ops.cs View File

@@ -118,6 +118,38 @@ namespace Tensorflow
return _softmax(logits, gen_nn_ops.log_softmax, axis, name);
}

/// <summary>
/// Performs the max pooling on the input.
/// </summary>
/// <param name="value">A 4-D `Tensor` of the format specified by `data_format`.</param>
/// <param name="ksize">
/// A list or tuple of 4 ints. The size of the window for each dimension
/// of the input tensor.
/// </param>
/// <param name="strides">
/// A list or tuple of 4 ints. The stride of the sliding window for
/// each dimension of the input tensor.
/// </param>
/// <param name="padding">A string, either `'VALID'` or `'SAME'`. The padding algorithm.</param>
/// <param name="data_format">A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.</param>
/// <param name="name">Optional name for the operation.</param>
/// <returns></returns>
public static Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
{
return with(ops.name_scope(name, "MaxPool", value), scope =>
{
name = scope;
value = ops.convert_to_tensor(value, name: "input");
return gen_nn_ops.max_pool(
value,
ksize: ksize,
strides: strides,
padding: padding,
data_format: data_format,
name: name);
});
}

public static Tensor _softmax(Tensor logits, Func<Tensor, string, Tensor> compute_op, int dim = -1, string name = null)
{
logits = ops.convert_to_tensor(logits);


+ 13
- 0
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -24,6 +24,16 @@ namespace Tensorflow

}

public TensorShape this[Slice slice]
{
get
{
return new TensorShape(Dimensions.Skip(slice.Start.Value)
.Take(slice.Length.Value)
.ToArray());
}
}

/// <summary>
/// Returns True iff `self` is fully defined in every dimension.
/// </summary>
@@ -38,6 +48,9 @@ namespace Tensorflow
throw new NotImplementedException("TensorShape is_compatible_with");
}

public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
}
}

+ 17
- 1
test/TensorFlowNET.Examples/IExample.cs View File

@@ -1,4 +1,20 @@
using System;
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;


+ 131
- 29
test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionCNN.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
@@ -65,7 +66,7 @@ namespace TensorFlowNET.Examples.ImageProcess


Tensor x, y;
Tensor loss, accuracy;
Tensor loss, accuracy, cls_prediction;
Operation optimizer;

int display_freq = 100;
@@ -90,47 +91,148 @@ namespace TensorFlowNET.Examples.ImageProcess
{
var graph = new Graph().as_default();

// Placeholders for inputs (x) and outputs(y)
x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
with(tf.name_scope("Input"), delegate
{
// Placeholders for inputs (x) and outputs(y)
x = tf.placeholder(tf.float32, shape: (-1, img_h, img_w, n_channels), name: "X");
y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
});

// Create a fully-connected layer with h1 nodes as hidden layer
var fc1 = fc_layer(x, h1, "FC1", use_relu: true);
// Create a fully-connected layer with n_classes nodes as output layer
var conv1 = conv_layer(x, filter_size1, num_filters1, stride1, name: "conv1");
var pool1 = max_pool(conv1, ksize: 2, stride: 2, name: "pool1");
var conv2 = conv_layer(pool1, filter_size2, num_filters2, stride2, name: "conv2");
var pool2 = max_pool(conv2, ksize: 2, stride: 2, name: "pool2");
var layer_flat = flatten_layer(pool2);
var fc1 = fc_layer(layer_flat, h1, "FC1", use_relu: true);
var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
// Define the loss function, optimizer, and accuracy
var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
loss = tf.reduce_mean(logits, name: "loss");
optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");

// Network predictions
var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
with(tf.variable_scope("Train"), delegate
{
with(tf.variable_scope("Loss"), delegate
{
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits), name: "loss");
});

with(tf.variable_scope("Optimizer"), delegate
{
optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
});

with(tf.variable_scope("Accuracy"), delegate
{
var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
});

with(tf.variable_scope("Prediction"), delegate
{
cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
});
});

return graph;
}

private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
/// <summary>
/// Create a 2D convolution layer
/// </summary>
/// <param name="x">input from previous layer</param>
/// <param name="filter_size">size of each filter</param>
/// <param name="num_filters">number of filters(or output feature maps)</param>
/// <param name="stride">filter stride</param>
/// <param name="name">layer name</param>
/// <returns>The output array</returns>
private Tensor conv_layer(Tensor x, int filter_size, int num_filters, int stride, string name)
{
return with(tf.variable_scope(name), delegate {

var num_in_channel = x.shape[x.NDims - 1];
var shape = new[] { filter_size, filter_size, num_in_channel, num_filters };
var W = weight_variable("W", shape);
// var tf.summary.histogram("weight", W);
var b = bias_variable("b", new[] { num_filters });
// tf.summary.histogram("bias", b);
var layer = tf.nn.conv2d(x, W,
strides: new[] { 1, stride, stride, 1 },
padding: "SAME");
layer += b;
return tf.nn.relu(layer);
});

}

/// <summary>
/// Create a max pooling layer
/// </summary>
/// <param name="x">input to max-pooling layer</param>
/// <param name="ksize">size of the max-pooling filter</param>
/// <param name="stride">stride of the max-pooling filter</param>
/// <param name="name">layer name</param>
/// <returns>The output array</returns>
private Tensor max_pool(Tensor x, int ksize, int stride, string name)
{
var in_dim = x.shape[1];
return tf.nn.max_pool(x,
ksize: new[] { 1, ksize, ksize, 1 },
strides: new[] { 1, stride, stride, 1 },
padding: "SAME",
name: name);
}

/// <summary>
/// Flattens the output of the convolutional layer to be fed into fully-connected layer
/// </summary>
/// <param name="layer">input array</param>
/// <returns>flattened array</returns>
private Tensor flatten_layer(Tensor layer)
{
return with(tf.variable_scope("Flatten_layer"), delegate
{
var layer_shape = layer.TensorShape;
var num_features = layer_shape[new Slice(1, 4)].Size;
var layer_flat = tf.reshape(layer, new[] { -1, num_features });

return layer_flat;
});
}

private Tensor weight_variable(string name, int[] shape)
{
var initer = tf.truncated_normal_initializer(stddev: 0.01f);
var W = tf.get_variable("W_" + name,
dtype: tf.float32,
shape: (in_dim, num_units),
initializer: initer);
return tf.get_variable(name,
dtype: tf.float32,
shape: shape,
initializer: initer);
}

var initial = tf.constant(0f, num_units);
var b = tf.get_variable("b_" + name,
dtype: tf.float32,
initializer: initial);
/// <summary>
/// Create a bias variable with appropriate initialization
/// </summary>
/// <param name="name"></param>
/// <param name="shape"></param>
/// <returns></returns>
private Tensor bias_variable(string name, int[] shape)
{
var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
return tf.get_variable(name,
dtype: tf.float32,
initializer: initial);
}

private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
{
return with(tf.variable_scope(name), delegate
{
var in_dim = x.shape[1];

var layer = tf.matmul(x, W) + b;
if (use_relu)
layer = tf.nn.relu(layer);
var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units });
var b = bias_variable("b_" + name, new[] { num_units });

return layer;
var layer = tf.matmul(x, W) + b;
if (use_relu)
layer = tf.nn.relu(layer);

return layer;
});
}

public Graph ImportGraph() => throw new NotImplementedException();


+ 17
- 1
test/TensorFlowNET.Examples/Utility/Compress.cs View File

@@ -1,4 +1,20 @@
using ICSharpCode.SharpZipLib.Core;
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using ICSharpCode.SharpZipLib.Core;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using System;


+ 17
- 1
test/TensorFlowNET.Examples/Utility/DataSetMnist.cs View File

@@ -1,4 +1,20 @@
using NumSharp;
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;


+ 17
- 1
test/TensorFlowNET.Examples/Utility/Web.cs View File

@@ -1,4 +1,20 @@
using System;
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;


Loading…
Cancel
Save