|
|
|
@@ -16,13 +16,13 @@ namespace TensorFlowNET.Examples |
|
|
|
public int Priority => 10;
|
|
|
|
public bool Enabled { get; set; } = true;
|
|
|
|
public string Name => "NN XOR";
|
|
|
|
public bool ImportGraph { get; set; } = true;
|
|
|
|
public bool ImportGraph { get; set; } = false;
|
|
|
|
|
|
|
|
public int num_steps = 5000;
|
|
|
|
|
|
|
|
private NDArray data;
|
|
|
|
|
|
|
|
private (Operation, Tensor, RefVariable) make_graph(Tensor features,Tensor labels, int num_hidden = 8)
|
|
|
|
private (Operation, Tensor, Tensor) make_graph(Tensor features,Tensor labels, int num_hidden = 8)
|
|
|
|
{
|
|
|
|
var stddev = 1 / Math.Sqrt(2);
|
|
|
|
var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, stddev: (float) stddev ));
|
|
|
|
@@ -69,6 +69,8 @@ namespace TensorFlowNET.Examples |
|
|
|
Tensor features = graph.get_operation_by_name("Placeholder");
|
|
|
|
Tensor labels = graph.get_operation_by_name("Placeholder_1");
|
|
|
|
Tensor loss = graph.get_operation_by_name("loss");
|
|
|
|
Tensor train_op = graph.get_operation_by_name("train_op");
|
|
|
|
Tensor global_step = graph.get_operation_by_name("global_step");
|
|
|
|
|
|
|
|
var init = tf.global_variables_initializer();
|
|
|
|
float loss_value = 0;
|
|
|
|
@@ -86,7 +88,8 @@ namespace TensorFlowNET.Examples |
|
|
|
// [train_op, gs, loss],
|
|
|
|
// feed_dict={features: xy, labels: y_}
|
|
|
|
// )
|
|
|
|
loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_));
|
|
|
|
var result = sess.run(new ITensorOrOperation[] { train_op, global_step, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
|
|
|
|
loss_value = result[2];
|
|
|
|
step++;
|
|
|
|
if (step % 1000 == 0)
|
|
|
|
Console.WriteLine($"Step {step} loss: {loss_value}");
|
|
|
|
@@ -122,7 +125,9 @@ namespace TensorFlowNET.Examples |
|
|
|
// [train_op, gs, loss],
|
|
|
|
// feed_dict={features: xy, labels: y_}
|
|
|
|
// )
|
|
|
|
loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_));
|
|
|
|
var result = sess.run(new ITensorOrOperation[] { train_op, gs, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
|
|
|
|
loss_value = result[2];
|
|
|
|
//step = result[1];
|
|
|
|
step++;
|
|
|
|
if (step % 1000 == 0)
|
|
|
|
Console.WriteLine($"Step {step} loss: {loss_value}");
|
|
|
|
|