| @@ -16,13 +16,13 @@ namespace TensorFlowNET.Examples | |||||
| public int Priority => 10; | public int Priority => 10; | ||||
| public bool Enabled { get; set; } = true; | public bool Enabled { get; set; } = true; | ||||
| public string Name => "NN XOR"; | public string Name => "NN XOR"; | ||||
| public bool ImportGraph { get; set; } = true; | |||||
| public bool ImportGraph { get; set; } = false; | |||||
| public int num_steps = 5000; | public int num_steps = 5000; | ||||
| private NDArray data; | private NDArray data; | ||||
| private (Operation, Tensor, RefVariable) make_graph(Tensor features,Tensor labels, int num_hidden = 8) | |||||
| private (Operation, Tensor, Tensor) make_graph(Tensor features,Tensor labels, int num_hidden = 8) | |||||
| { | { | ||||
| var stddev = 1 / Math.Sqrt(2); | var stddev = 1 / Math.Sqrt(2); | ||||
| var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, stddev: (float) stddev )); | var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, stddev: (float) stddev )); | ||||
| @@ -69,6 +69,8 @@ namespace TensorFlowNET.Examples | |||||
| Tensor features = graph.get_operation_by_name("Placeholder"); | Tensor features = graph.get_operation_by_name("Placeholder"); | ||||
| Tensor labels = graph.get_operation_by_name("Placeholder_1"); | Tensor labels = graph.get_operation_by_name("Placeholder_1"); | ||||
| Tensor loss = graph.get_operation_by_name("loss"); | Tensor loss = graph.get_operation_by_name("loss"); | ||||
| Tensor train_op = graph.get_operation_by_name("train_op"); | |||||
| Tensor global_step = graph.get_operation_by_name("global_step"); | |||||
| var init = tf.global_variables_initializer(); | var init = tf.global_variables_initializer(); | ||||
| float loss_value = 0; | float loss_value = 0; | ||||
| @@ -86,7 +88,8 @@ namespace TensorFlowNET.Examples | |||||
| // [train_op, gs, loss], | // [train_op, gs, loss], | ||||
| // feed_dict={features: xy, labels: y_} | // feed_dict={features: xy, labels: y_} | ||||
| // ) | // ) | ||||
| loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_)); | |||||
| var result = sess.run(new ITensorOrOperation[] { train_op, global_step, loss }, new FeedItem(features, data), new FeedItem(labels, y_)); | |||||
| loss_value = result[2]; | |||||
| step++; | step++; | ||||
| if (step % 1000 == 0) | if (step % 1000 == 0) | ||||
| Console.WriteLine($"Step {step} loss: {loss_value}"); | Console.WriteLine($"Step {step} loss: {loss_value}"); | ||||
| @@ -122,7 +125,9 @@ namespace TensorFlowNET.Examples | |||||
| // [train_op, gs, loss], | // [train_op, gs, loss], | ||||
| // feed_dict={features: xy, labels: y_} | // feed_dict={features: xy, labels: y_} | ||||
| // ) | // ) | ||||
| loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_)); | |||||
| var result = sess.run(new ITensorOrOperation[] { train_op, gs, loss }, new FeedItem(features, data), new FeedItem(labels, y_)); | |||||
| loss_value = result[2]; | |||||
| //step = result[1]; | |||||
| step++; | step++; | ||||
| if (step % 1000 == 0) | if (step % 1000 == 0) | ||||
| Console.WriteLine($"Step {step} loss: {loss_value}"); | Console.WriteLine($"Step {step} loss: {loss_value}"); | ||||