| @@ -18,6 +18,9 @@ namespace Tensorflow | |||||
| public static Tensor subtract<T>(Tensor x, T[] y, string name = null) where T : struct | public static Tensor subtract<T>(Tensor x, T[] y, string name = null) where T : struct | ||||
| => gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name); | => gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name); | ||||
| public static Tensor log(Tensor x, string name = null) | |||||
| => gen_math_ops.log(x, name); | |||||
| public static Tensor multiply(Tensor x, Tensor y) | public static Tensor multiply(Tensor x, Tensor y) | ||||
| => gen_math_ops.mul(x, y); | => gen_math_ops.mul(x, y); | ||||
| @@ -33,11 +36,11 @@ namespace Tensorflow | |||||
| /// <param name="input"></param> | /// <param name="input"></param> | ||||
| /// <param name="axis"></param> | /// <param name="axis"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static Tensor reduce_sum(Tensor input, int[] axis = null) | |||||
| public static Tensor reduce_sum(Tensor input, int[] axis = null, int? reduction_indices = null) | |||||
| => math_ops.reduce_sum(input); | => math_ops.reduce_sum(input); | ||||
| public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | |||||
| => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name); | |||||
| public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) | |||||
| => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); | |||||
| public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) | public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) | ||||
| => math_ops.cast(x, dtype, name); | => math_ops.cast(x, dtype, name); | ||||
| @@ -56,6 +56,9 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public static Tensor softmax(Tensor logits, int axis = -1, string name = null) | |||||
| => gen_nn_ops.softmax(logits, name); | |||||
| public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null) | public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null) | ||||
| => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name); | => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name); | ||||
| } | } | ||||
| @@ -146,6 +146,16 @@ namespace Tensorflow.Operations | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| public static Tensor softmax(Tensor logits, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("Softmax", name: name, args: new | |||||
| { | |||||
| logits | |||||
| }); | |||||
| return _op.outputs[0]; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Computes softmax cross entropy cost and gradients to backpropagate. | /// Computes softmax cross entropy cost and gradients to backpropagate. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -42,7 +42,7 @@ namespace Tensorflow | |||||
| else | else | ||||
| { | { | ||||
| tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape()); | tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape()); | ||||
| var c = constant_op.constant(0); | |||||
| var c = constant_op.constant(0, dtype: dtype); | |||||
| return gen_array_ops.fill(tShape, c, name: name); | return gen_array_ops.fill(tShape, c, name: name); | ||||
| } | } | ||||
| } | } | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow | |||||
| /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param> | /// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param> | ||||
| /// <param name="keepdims"> If true, retains reduced dimensions with length 1.</param> | /// <param name="keepdims"> If true, retains reduced dimensions with length 1.</param> | ||||
| /// <param name="name"> A name for the operation (optional).</param> | /// <param name="name"> A name for the operation (optional).</param> | ||||
| public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null) | |||||
| public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) | |||||
| { | { | ||||
| var r = _ReductionDims(input_tensor, axis); | var r = _ReductionDims(input_tensor, axis); | ||||
| if (axis == null) | if (axis == null) | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using NumSharp.Core; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| @@ -13,6 +14,11 @@ namespace TensorFlowNET.Examples | |||||
| /// </summary> | /// </summary> | ||||
| public class LogisticRegression : Python, IExample | public class LogisticRegression : Python, IExample | ||||
| { | { | ||||
| private float learning_rate = 0.01f; | |||||
| private int training_epochs = 25; | |||||
| private int batch_size = 100; | |||||
| private int display_step = 1; | |||||
| public void Run() | public void Run() | ||||
| { | { | ||||
| PrepareData(); | PrepareData(); | ||||
| @@ -20,8 +26,34 @@ namespace TensorFlowNET.Examples | |||||
| private void PrepareData() | private void PrepareData() | ||||
| { | { | ||||
| MnistDataSet.read_data_sets("logistic_regression", one_hot: true); | |||||
| var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true); | |||||
| // tf Graph Input | |||||
| var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784 | |||||
| var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes | |||||
| // Set model weights | |||||
| var W = tf.Variable(tf.zeros(new Shape(784, 10))); | |||||
| var b = tf.Variable(tf.zeros(new Shape(10))); | |||||
| // Construct model | |||||
| var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax | |||||
| // Minimize error using cross entropy | |||||
| var sum = -tf.reduce_sum(y * tf.log(pred), reduction_indices: 1); | |||||
| var cost = tf.reduce_mean(sum); | |||||
| // Gradient Descent | |||||
| var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | |||||
| // Initialize the variables (i.e. assign their default value) | |||||
| var init = tf.global_variables_initializer(); | |||||
| with(tf.Session(), sess => | |||||
| { | |||||
| // Run the initializer | |||||
| sess.run(init); | |||||
| }); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -9,12 +9,22 @@ namespace TensorFlowNET.Examples.Utility | |||||
| public class DataSet | public class DataSet | ||||
| { | { | ||||
| private int _num_examples; | private int _num_examples; | ||||
| private int _epochs_completed; | |||||
| private int _index_in_epoch; | |||||
| private NDArray _images; | |||||
| private NDArray _labels; | |||||
| public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) | public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) | ||||
| { | { | ||||
| _num_examples = images.shape[0]; | _num_examples = images.shape[0]; | ||||
| images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | ||||
| images.astype(dtype.as_numpy_datatype()); | |||||
| images = np.multiply(images, 1.0f / 255.0f); | images = np.multiply(images, 1.0f / 255.0f); | ||||
| _images = images; | |||||
| _labels = labels; | |||||
| _epochs_completed = 0; | |||||
| _index_in_epoch = 0; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,25 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace TensorFlowNET.Examples.Utility | |||||
| { | |||||
| public class Datasets | |||||
| { | |||||
| private DataSet _train; | |||||
| public DataSet train => _train; | |||||
| private DataSet _validation; | |||||
| public DataSet validation => _validation; | |||||
| private DataSet _test; | |||||
| public DataSet test => _test; | |||||
| public Datasets(DataSet train, DataSet validation, DataSet test) | |||||
| { | |||||
| _train = train; | |||||
| _validation = validation; | |||||
| _test = test; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -18,9 +18,9 @@ namespace TensorFlowNET.Examples.Utility | |||||
| private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | ||||
| private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | ||||
| public static void read_data_sets(string train_dir, | |||||
| public static Datasets read_data_sets(string train_dir, | |||||
| bool one_hot = false, | bool one_hot = false, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| bool reshape = true, | bool reshape = true, | ||||
| int validation_size = 5000, | int validation_size = 5000, | ||||
| string source_url = DEFAULT_SOURCE_URL) | string source_url = DEFAULT_SOURCE_URL) | ||||
| @@ -48,6 +48,10 @@ namespace TensorFlowNET.Examples.Utility | |||||
| train_labels = train_labels[np.arange(validation_size, end)]; | train_labels = train_labels[np.arange(validation_size, end)]; | ||||
| var train = new DataSet(train_images, train_labels, dtype, reshape); | var train = new DataSet(train_images, train_labels, dtype, reshape); | ||||
| var validation = new DataSet(validation_images, validation_labels, dtype, reshape); | |||||
| var test = new DataSet(test_images, test_labels, dtype, reshape); | |||||
| return new Datasets(train, validation, test); | |||||
| } | } | ||||
| public static NDArray extract_images(string file) | public static NDArray extract_images(string file) | ||||