From 6b154dd1bbe31869ecd046321c8ce11c682179ba Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 20 Mar 2019 09:01:08 -0500 Subject: [PATCH] tf.log, tf.nn.softmax --- src/TensorFlowNET.Core/APIs/tf.math.cs | 9 +++-- src/TensorFlowNET.Core/APIs/tf.nn.cs | 3 ++ .../Operations/NnOps/gen_nn_ops.cs | 10 +++++ .../Operations/array_ops.py.cs | 2 +- .../Operations/math_ops.py.cs | 2 +- .../LogisticRegression.cs | 38 +++++++++++++++++-- .../TensorFlowNET.Examples/Utility/DataSet.cs | 10 +++++ .../Utility/Datasets.cs | 25 ++++++++++++ .../Utility/MnistDataSet.cs | 8 +++- 9 files changed, 97 insertions(+), 10 deletions(-) create mode 100644 test/TensorFlowNET.Examples/Utility/Datasets.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 0534b645..e01da00c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -18,6 +18,9 @@ namespace Tensorflow public static Tensor subtract(Tensor x, T[] y, string name = null) where T : struct => gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name); + public static Tensor log(Tensor x, string name = null) + => gen_math_ops.log(x, name); + public static Tensor multiply(Tensor x, Tensor y) => gen_math_ops.mul(x, y); @@ -33,11 +36,11 @@ namespace Tensorflow /// /// /// - public static Tensor reduce_sum(Tensor input, int[] axis = null) + public static Tensor reduce_sum(Tensor input, int[] axis = null, int? reduction_indices = null) => math_ops.reduce_sum(input); - public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) - => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name); + public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) + => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) => math_ops.cast(x, dtype, name); diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 1288508c..8a1b648e 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -56,6 +56,9 @@ namespace Tensorflow }); } + public static Tensor softmax(Tensor logits, int axis = -1, string name = null) + => gen_nn_ops.softmax(logits, name); + public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null) => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name); } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index dd99a1ff..72ca1c1b 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -146,6 +146,16 @@ namespace Tensorflow.Operations return _op.outputs[0]; } + public static Tensor softmax(Tensor logits, string name = null) + { + var _op = _op_def_lib._apply_op_helper("Softmax", name: name, args: new + { + logits + }); + + return _op.outputs[0]; + } + /// /// Computes softmax cross entropy cost and gradients to backpropagate. /// diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index dc793fea..e6b3671c 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -42,7 +42,7 @@ namespace Tensorflow else { tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape()); - var c = constant_op.constant(0); + var c = constant_op.constant(0, dtype: dtype); return gen_array_ops.fill(tShape, c, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/math_ops.py.cs b/src/TensorFlowNET.Core/Operations/math_ops.py.cs index 4741214e..fd2a0644 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.py.cs @@ -38,7 +38,7 @@ namespace Tensorflow /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`. /// If true, retains reduced dimensions with length 1. /// A name for the operation (optional). - public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) + public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) { var r = _ReductionDims(input_tensor, axis); if (axis == null) diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs index b9d01c6d..71760404 100644 --- a/test/TensorFlowNET.Examples/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/LogisticRegression.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp.Core; +using System; using System.Collections.Generic; using System.Text; using Tensorflow; @@ -13,6 +14,11 @@ namespace TensorFlowNET.Examples /// public class LogisticRegression : Python, IExample { + private float learning_rate = 0.01f; + private int training_epochs = 25; + private int batch_size = 100; + private int display_step = 1; + public void Run() { PrepareData(); @@ -20,8 +26,34 @@ namespace TensorFlowNET.Examples private void PrepareData() { - MnistDataSet.read_data_sets("logistic_regression", one_hot: true); - + var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); + + // tf Graph Input + var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784 + var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes + + // Set model weights + var W = tf.Variable(tf.zeros(new Shape(784, 10))); + var b = tf.Variable(tf.zeros(new Shape(10))); + + // Construct model + var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax + + // Minimize error using cross entropy + var sum = -tf.reduce_sum(y * tf.log(pred), reduction_indices: 1); + var cost = tf.reduce_mean(sum); + + // Gradient Descent + var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); + + // Initialize the variables (i.e. assign their default value) + var init = tf.global_variables_initializer(); + + with(tf.Session(), sess => + { + // Run the initializer + sess.run(init); + }); } } } diff --git a/test/TensorFlowNET.Examples/Utility/DataSet.cs b/test/TensorFlowNET.Examples/Utility/DataSet.cs index 1005aec3..7ace7b94 100644 --- a/test/TensorFlowNET.Examples/Utility/DataSet.cs +++ b/test/TensorFlowNET.Examples/Utility/DataSet.cs @@ -9,12 +9,22 @@ namespace TensorFlowNET.Examples.Utility public class DataSet { private int _num_examples; + private int _epochs_completed; + private int _index_in_epoch; + private NDArray _images; + private NDArray _labels; public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) { _num_examples = images.shape[0]; images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); + images.astype(dtype.as_numpy_datatype()); images = np.multiply(images, 1.0f / 255.0f); + + _images = images; + _labels = labels; + _epochs_completed = 0; + _index_in_epoch = 0; } } } diff --git a/test/TensorFlowNET.Examples/Utility/Datasets.cs b/test/TensorFlowNET.Examples/Utility/Datasets.cs new file mode 100644 index 00000000..660e40db --- /dev/null +++ b/test/TensorFlowNET.Examples/Utility/Datasets.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace TensorFlowNET.Examples.Utility +{ + public class Datasets + { + private DataSet _train; + public DataSet train => _train; + + private DataSet _validation; + public DataSet validation => _validation; + + private DataSet _test; + public DataSet test => _test; + + public Datasets(DataSet train, DataSet validation, DataSet test) + { + _train = train; + _validation = validation; + _test = test; + } + } +} diff --git a/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs b/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs index 05ad2970..f54fd95c 100644 --- a/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs +++ b/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs @@ -18,9 +18,9 @@ namespace TensorFlowNET.Examples.Utility private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; - public static void read_data_sets(string train_dir, + public static Datasets read_data_sets(string train_dir, bool one_hot = false, - TF_DataType dtype = TF_DataType.DtInvalid, + TF_DataType dtype = TF_DataType.TF_FLOAT, bool reshape = true, int validation_size = 5000, string source_url = DEFAULT_SOURCE_URL) @@ -48,6 +48,10 @@ namespace TensorFlowNET.Examples.Utility train_labels = train_labels[np.arange(validation_size, end)]; var train = new DataSet(train_images, train_labels, dtype, reshape); + var validation = new DataSet(validation_images, validation_labels, dtype, reshape); + var test = new DataSet(test_images, test_labels, dtype, reshape); + + return new Datasets(train, validation, test); } public static NDArray extract_images(string file)