|
|
|
@@ -3,6 +3,7 @@ using System.Collections.Generic; |
|
|
|
using System.Text;
|
|
|
|
using NumSharp;
|
|
|
|
using Tensorflow;
|
|
|
|
using TensorFlowNET.Examples.Utility;
|
|
|
|
|
|
|
|
namespace TensorFlowNET.Examples
|
|
|
|
{
|
|
|
|
@@ -15,6 +16,7 @@ namespace TensorFlowNET.Examples |
|
|
|
public int Priority => 10;
|
|
|
|
public bool Enabled { get; set; } = true;
|
|
|
|
public string Name => "NN XOR";
|
|
|
|
public bool ImportGraph { get; set; } = true;
|
|
|
|
|
|
|
|
public int num_steps = 5000;
|
|
|
|
|
|
|
|
@@ -38,7 +40,7 @@ namespace TensorFlowNET.Examples |
|
|
|
|
|
|
|
// Shape [4]
|
|
|
|
var predictions = tf.sigmoid(tf.squeeze(logits));
|
|
|
|
var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)));
|
|
|
|
var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)), name:"loss");
|
|
|
|
|
|
|
|
var gs = tf.Variable(0, trainable: false);
|
|
|
|
var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs);
|
|
|
|
@@ -49,7 +51,53 @@ namespace TensorFlowNET.Examples |
|
|
|
public bool Run()
|
|
|
|
{
|
|
|
|
PrepareData();
|
|
|
|
float loss_value = 0;
|
|
|
|
if (ImportGraph)
|
|
|
|
loss_value = RunWithImportedGraph();
|
|
|
|
else
|
|
|
|
loss_value=RunWithBuiltGraph();
|
|
|
|
|
|
|
|
return loss_value < 0.0627;
|
|
|
|
}
|
|
|
|
|
|
|
|
private float RunWithImportedGraph()
|
|
|
|
{
|
|
|
|
var graph = tf.Graph().as_default();
|
|
|
|
|
|
|
|
tf.train.import_meta_graph("graph/xor.meta");
|
|
|
|
|
|
|
|
var features = graph.get_operation_by_name("Placeholder");
|
|
|
|
var labels = graph.get_operation_by_name("Placeholder_1");
|
|
|
|
Tensor loss = graph.get_operation_by_name("loss");
|
|
|
|
|
|
|
|
var init = tf.global_variables_initializer();
|
|
|
|
float loss_value = 0;
|
|
|
|
// Start tf session
|
|
|
|
with<Session>(tf.Session(graph), sess =>
|
|
|
|
{
|
|
|
|
sess.run(init);
|
|
|
|
var step = 0;
|
|
|
|
|
|
|
|
var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32);
|
|
|
|
while (step < num_steps)
|
|
|
|
{
|
|
|
|
// original python:
|
|
|
|
//_, step, loss_value = sess.run(
|
|
|
|
// [train_op, gs, loss],
|
|
|
|
// feed_dict={features: xy, labels: y_}
|
|
|
|
// )
|
|
|
|
loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_));
|
|
|
|
step++;
|
|
|
|
if (step % 1000 == 0)
|
|
|
|
Console.WriteLine($"Step {step} loss: {loss_value}");
|
|
|
|
}
|
|
|
|
Console.WriteLine($"Final loss: {loss_value}");
|
|
|
|
});
|
|
|
|
return loss_value;
|
|
|
|
}
|
|
|
|
|
|
|
|
private float RunWithBuiltGraph()
|
|
|
|
{
|
|
|
|
var graph = tf.Graph().as_default();
|
|
|
|
|
|
|
|
var features = tf.placeholder(tf.float32, new TensorShape(4, 2));
|
|
|
|
@@ -76,12 +124,12 @@ namespace TensorFlowNET.Examples |
|
|
|
// )
|
|
|
|
loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_));
|
|
|
|
step++;
|
|
|
|
if (step%1000==0)
|
|
|
|
if (step % 1000 == 0)
|
|
|
|
Console.WriteLine($"Step {step} loss: {loss_value}");
|
|
|
|
}
|
|
|
|
Console.WriteLine($"Final loss: {loss_value}");
|
|
|
|
});
|
|
|
|
return loss_value < 0.0627;
|
|
|
|
return loss_value;
|
|
|
|
}
|
|
|
|
|
|
|
|
public void PrepareData()
|
|
|
|
@@ -93,6 +141,10 @@ namespace TensorFlowNET.Examples |
|
|
|
{0, 0 },
|
|
|
|
{0, 1 }
|
|
|
|
};
|
|
|
|
|
|
|
|
// download graph meta data
|
|
|
|
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta";
|
|
|
|
Web.Download(url, "graph", "kmeans.meta");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|