Browse Source

tf.log, tf.nn.softmax

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
6b154dd1bb
9 changed files with 97 additions and 10 deletions
  1. +6
    -3
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  3. +10
    -0
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/math_ops.py.cs
  6. +35
    -3
      test/TensorFlowNET.Examples/LogisticRegression.cs
  7. +10
    -0
      test/TensorFlowNET.Examples/Utility/DataSet.cs
  8. +25
    -0
      test/TensorFlowNET.Examples/Utility/Datasets.cs
  9. +6
    -2
      test/TensorFlowNET.Examples/Utility/MnistDataSet.cs

+ 6
- 3
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -18,6 +18,9 @@ namespace Tensorflow
public static Tensor subtract<T>(Tensor x, T[] y, string name = null) where T : struct
=> gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name);

public static Tensor log(Tensor x, string name = null)
=> gen_math_ops.log(x, name);

public static Tensor multiply(Tensor x, Tensor y)
=> gen_math_ops.mul(x, y);

@@ -33,11 +36,11 @@ namespace Tensorflow
/// <param name="input"></param>
/// <param name="axis"></param>
/// <returns></returns>
public static Tensor reduce_sum(Tensor input, int[] axis = null)
public static Tensor reduce_sum(Tensor input, int[] axis = null, int? reduction_indices = null)
=> math_ops.reduce_sum(input);

public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name);
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);

public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
=> math_ops.cast(x, dtype, name);


+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -56,6 +56,9 @@ namespace Tensorflow
});
}

public static Tensor softmax(Tensor logits, int axis = -1, string name = null)
=> gen_nn_ops.softmax(logits, name);

public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
}


+ 10
- 0
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -146,6 +146,16 @@ namespace Tensorflow.Operations
return _op.outputs[0];
}

public static Tensor softmax(Tensor logits, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Softmax", name: name, args: new
{
logits
});

return _op.outputs[0];
}

/// <summary>
/// Computes softmax cross entropy cost and gradients to backpropagate.
/// </summary>


+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow
else
{
tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape());
var c = constant_op.constant(0);
var c = constant_op.constant(0, dtype: dtype);
return gen_array_ops.fill(tShape, c, name: name);
}
}


+ 1
- 1
src/TensorFlowNET.Core/Operations/math_ops.py.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow
/// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param>
/// <param name="keepdims"> If true, retains reduced dimensions with length 1.</param>
/// <param name="name"> A name for the operation (optional).</param>
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
{
var r = _ReductionDims(input_tensor, axis);
if (axis == null)


+ 35
- 3
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;
@@ -13,6 +14,11 @@ namespace TensorFlowNET.Examples
/// </summary>
public class LogisticRegression : Python, IExample
{
private float learning_rate = 0.01f;
private int training_epochs = 25;
private int batch_size = 100;
private int display_step = 1;

public void Run()
{
PrepareData();
@@ -20,8 +26,34 @@ namespace TensorFlowNET.Examples

private void PrepareData()
{
MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);

// tf Graph Input
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes

// Set model weights
var W = tf.Variable(tf.zeros(new Shape(784, 10)));
var b = tf.Variable(tf.zeros(new Shape(10)));

// Construct model
var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax

// Minimize error using cross entropy
var sum = -tf.reduce_sum(y * tf.log(pred), reduction_indices: 1);
var cost = tf.reduce_mean(sum);

// Gradient Descent
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);

// Initialize the variables (i.e. assign their default value)
var init = tf.global_variables_initializer();

with(tf.Session(), sess =>
{
// Run the initializer
sess.run(init);
});
}
}
}

+ 10
- 0
test/TensorFlowNET.Examples/Utility/DataSet.cs View File

@@ -9,12 +9,22 @@ namespace TensorFlowNET.Examples.Utility
public class DataSet
{
private int _num_examples;
private int _epochs_completed;
private int _index_in_epoch;
private NDArray _images;
private NDArray _labels;

public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
{
_num_examples = images.shape[0];
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
images.astype(dtype.as_numpy_datatype());
images = np.multiply(images, 1.0f / 255.0f);

_images = images;
_labels = labels;
_epochs_completed = 0;
_index_in_epoch = 0;
}
}
}

+ 25
- 0
test/TensorFlowNET.Examples/Utility/Datasets.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace TensorFlowNET.Examples.Utility
{
public class Datasets
{
private DataSet _train;
public DataSet train => _train;

private DataSet _validation;
public DataSet validation => _validation;

private DataSet _test;
public DataSet test => _test;

public Datasets(DataSet train, DataSet validation, DataSet test)
{
_train = train;
_validation = validation;
_test = test;
}
}
}

+ 6
- 2
test/TensorFlowNET.Examples/Utility/MnistDataSet.cs View File

@@ -18,9 +18,9 @@ namespace TensorFlowNET.Examples.Utility
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

public static void read_data_sets(string train_dir,
public static Datasets read_data_sets(string train_dir,
bool one_hot = false,
TF_DataType dtype = TF_DataType.DtInvalid,
TF_DataType dtype = TF_DataType.TF_FLOAT,
bool reshape = true,
int validation_size = 5000,
string source_url = DEFAULT_SOURCE_URL)
@@ -48,6 +48,10 @@ namespace TensorFlowNET.Examples.Utility
train_labels = train_labels[np.arange(validation_size, end)];

var train = new DataSet(train_images, train_labels, dtype, reshape);
var validation = new DataSet(validation_images, validation_labels, dtype, reshape);
var test = new DataSet(test_images, test_labels, dtype, reshape);

return new Datasets(train, validation, test);
}

public static NDArray extract_images(string file)


Loading…
Cancel
Save