Browse Source

Xor is working without imported graph!

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
0b99f3cd79
1 changed files with 9 additions and 4 deletions
  1. +9
    -4
      test/TensorFlowNET.Examples/NeuralNetXor.cs

+ 9
- 4
test/TensorFlowNET.Examples/NeuralNetXor.cs View File

@@ -16,13 +16,13 @@ namespace TensorFlowNET.Examples
public int Priority => 10; public int Priority => 10;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "NN XOR"; public string Name => "NN XOR";
public bool ImportGraph { get; set; } = true;
public bool ImportGraph { get; set; } = false;
public int num_steps = 5000; public int num_steps = 5000;
private NDArray data; private NDArray data;
private (Operation, Tensor, RefVariable) make_graph(Tensor features,Tensor labels, int num_hidden = 8)
private (Operation, Tensor, Tensor) make_graph(Tensor features,Tensor labels, int num_hidden = 8)
{ {
var stddev = 1 / Math.Sqrt(2); var stddev = 1 / Math.Sqrt(2);
var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, stddev: (float) stddev )); var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, stddev: (float) stddev ));
@@ -69,6 +69,8 @@ namespace TensorFlowNET.Examples
Tensor features = graph.get_operation_by_name("Placeholder"); Tensor features = graph.get_operation_by_name("Placeholder");
Tensor labels = graph.get_operation_by_name("Placeholder_1"); Tensor labels = graph.get_operation_by_name("Placeholder_1");
Tensor loss = graph.get_operation_by_name("loss"); Tensor loss = graph.get_operation_by_name("loss");
Tensor train_op = graph.get_operation_by_name("train_op");
Tensor global_step = graph.get_operation_by_name("global_step");
var init = tf.global_variables_initializer(); var init = tf.global_variables_initializer();
float loss_value = 0; float loss_value = 0;
@@ -86,7 +88,8 @@ namespace TensorFlowNET.Examples
// [train_op, gs, loss], // [train_op, gs, loss],
// feed_dict={features: xy, labels: y_} // feed_dict={features: xy, labels: y_}
// ) // )
loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_));
var result = sess.run(new ITensorOrOperation[] { train_op, global_step, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
loss_value = result[2];
step++; step++;
if (step % 1000 == 0) if (step % 1000 == 0)
Console.WriteLine($"Step {step} loss: {loss_value}"); Console.WriteLine($"Step {step} loss: {loss_value}");
@@ -122,7 +125,9 @@ namespace TensorFlowNET.Examples
// [train_op, gs, loss], // [train_op, gs, loss],
// feed_dict={features: xy, labels: y_} // feed_dict={features: xy, labels: y_}
// ) // )
loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_));
var result = sess.run(new ITensorOrOperation[] { train_op, gs, loss }, new FeedItem(features, data), new FeedItem(labels, y_));
loss_value = result[2];
//step = result[1];
step++; step++;
if (step % 1000 == 0) if (step % 1000 == 0)
Console.WriteLine($"Step {step} loss: {loss_value}"); Console.WriteLine($"Step {step} loss: {loss_value}");


Loading…
Cancel
Save